Skip to content

Commit b6b6eb5

Browse files
committed
simplify: remove unused broadcasting logic from the FloatDiv and FloatMul TileConstraints; support only the scalar and element-wise cases
1 parent fc8ea3f commit b6b6eb5

2 files changed

Lines changed: 58 additions & 122 deletions

File tree

Deeploy/Targets/Snitch/TileConstraints/FloatDivTileConstraint.py

Lines changed: 29 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,7 @@
1717

1818

1919
class FloatDivTileConstraint(TileConstraint):
20-
"""Tile constraint for FP32 Div operation with ONNX broadcasting support.
21-
22-
Supports general NumPy-style broadcasting: both inputs can have any
23-
dimension, including scalar, partial broadcasting, and full element-wise.
24-
"""
20+
"""Tile constraint for FP32 Div: supports scalar and element-wise cases."""
2521

2622
dataIn1Name = "A"
2723
dataIn2Name = "B"
@@ -34,41 +30,32 @@ def addGeometricalConstraint(cls, tilerModel: TilerModel, parseDict: Dict, ctxt:
3430
inputBuffer2Name = parseDict[cls.dataIn2Name]
3531
outputBufferName = parseDict[cls.dataOutName]
3632

37-
input1Shape = list(ctxt.lookup(inputBuffer1Name).shape)
38-
input2Shape = list(ctxt.lookup(inputBuffer2Name).shape)
39-
outputShape = list(ctxt.lookup(outputBufferName).shape)
40-
41-
# Add all tensors to model
4233
tilerModel.addTensorDimToModel(ctxt, inputBuffer1Name)
4334
tilerModel.addTensorDimToModel(ctxt, inputBuffer2Name)
4435
tilerModel.addTensorDimToModel(ctxt, outputBufferName)
4536

46-
outNdim = len(outputShape)
47-
48-
# Pad input shapes from the left to match output ndim (ONNX broadcasting)
49-
padded1 = [1] * (outNdim - len(input1Shape)) + input1Shape
50-
padded2 = [1] * (outNdim - len(input2Shape)) + input2Shape
51-
52-
for outDim in range(outNdim):
53-
outputDimVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = outDim)
54-
55-
# Input 1: map output dim to actual tensor dim
56-
in1ActualDim = outDim - (outNdim - len(input1Shape))
57-
if in1ActualDim >= 0:
58-
in1DimVar = tilerModel.getTensorDimVar(tensorName = inputBuffer1Name, dimIdx = in1ActualDim)
59-
if padded1[outDim] == 1:
60-
tilerModel.addConstraint(in1DimVar == 1)
61-
else:
62-
tilerModel.addConstraint(in1DimVar == outputDimVar)
63-
64-
# Input 2: map output dim to actual tensor dim
65-
in2ActualDim = outDim - (outNdim - len(input2Shape))
66-
if in2ActualDim >= 0:
67-
in2DimVar = tilerModel.getTensorDimVar(tensorName = inputBuffer2Name, dimIdx = in2ActualDim)
68-
if padded2[outDim] == 1:
69-
tilerModel.addConstraint(in2DimVar == 1)
70-
else:
71-
tilerModel.addConstraint(in2DimVar == outputDimVar)
37+
input1Shape = list(ctxt.lookup(inputBuffer1Name).shape)
38+
input2Shape = list(ctxt.lookup(inputBuffer2Name).shape)
39+
40+
is_scalar = (np.prod(input2Shape) == 1)
41+
42+
if is_scalar:
43+
# Scalar: tile A and C together, B stays fixed
44+
for dim in range(len(input1Shape)):
45+
in1Var = tilerModel.getTensorDimVar(tensorName = inputBuffer1Name, dimIdx = dim)
46+
outVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = dim)
47+
tilerModel.addConstraint(in1Var == outVar)
48+
for dim in range(len(input2Shape)):
49+
in2Var = tilerModel.getTensorDimVar(tensorName = inputBuffer2Name, dimIdx = dim)
50+
tilerModel.addConstraint(in2Var == input2Shape[dim])
51+
else:
52+
# Element-wise: all three tensors tiled identically
53+
for dim in range(len(input1Shape)):
54+
in1Var = tilerModel.getTensorDimVar(tensorName = inputBuffer1Name, dimIdx = dim)
55+
in2Var = tilerModel.getTensorDimVar(tensorName = inputBuffer2Name, dimIdx = dim)
56+
outVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = dim)
57+
tilerModel.addConstraint(in1Var == in2Var)
58+
tilerModel.addConstraint(in1Var == outVar)
7259

7360
return tilerModel
7461

@@ -86,38 +73,19 @@ def serializeTilingSolution(
8673
replacements = {"size": []}
8774
replacementTypes = {"size": PointerClass(uint16_t)}
8875

89-
input1Shape = list(ctxt.lookup(operatorRepresentation[cls.dataIn1Name]).shape)
9076
input2Shape = list(ctxt.lookup(operatorRepresentation[cls.dataIn2Name]).shape)
91-
outputShape = list(ctxt.lookup(operatorRepresentation[cls.dataOutName]).shape)
92-
93-
outNdim = len(outputShape)
94-
padded1 = [1] * (outNdim - len(input1Shape)) + input1Shape
95-
padded2 = [1] * (outNdim - len(input2Shape)) + input2Shape
96-
97-
def _deriveInputCube(outputCube, inputShape, paddedShape):
98-
"""Derive an input HyperRectangle from an output cube, respecting broadcasting."""
99-
offset = []
100-
dims = []
101-
for outDim in range(outNdim):
102-
actualDim = outDim - (outNdim - len(inputShape))
103-
if actualDim >= 0:
104-
if paddedShape[outDim] == 1:
105-
offset.append(0)
106-
dims.append(1)
107-
else:
108-
offset.append(outputCube.offset[outDim])
109-
dims.append(outputCube.dims[outDim])
110-
return HyperRectangle(tuple(offset), tuple(dims))
77+
is_scalar = (np.prod(input2Shape) == 1)
11178

11279
inputLoadSchedule = []
11380
outputLoadSchedule = []
11481

11582
for cube in outputCubes:
11683
replacements["size"].append(np.prod(cube.dims))
117-
118-
in1Cube = _deriveInputCube(cube, input1Shape, padded1)
119-
in2Cube = _deriveInputCube(cube, input2Shape, padded2)
120-
inputLoadSchedule.append({cls.dataIn1Name: in1Cube, cls.dataIn2Name: in2Cube})
84+
if is_scalar:
85+
in2Cube = HyperRectangle(tuple([0] * len(input2Shape)), tuple(input2Shape))
86+
inputLoadSchedule.append({cls.dataIn1Name: cube, cls.dataIn2Name: in2Cube})
87+
else:
88+
inputLoadSchedule.append({cls.dataIn1Name: cube, cls.dataIn2Name: cube})
12189

12290
for out in outputCubes:
12391
outputLoadSchedule.append({cls.dataOutName: out})

Deeploy/Targets/Snitch/TileConstraints/FloatMulTileConstraint.py

Lines changed: 29 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,7 @@
1717

1818

1919
class FloatMulTileConstraint(TileConstraint):
20-
"""Tile constraint for FP32 Mul operation with ONNX broadcasting support.
21-
22-
Supports general NumPy-style broadcasting: both inputs can have any
23-
dimension, including scalar, partial broadcasting, and full element-wise.
24-
"""
20+
"""Tile constraint for FP32 Mul: supports scalar and element-wise cases."""
2521

2622
dataIn1Name = "A"
2723
dataIn2Name = "B"
@@ -34,41 +30,32 @@ def addGeometricalConstraint(cls, tilerModel: TilerModel, parseDict: Dict, ctxt:
3430
inputBuffer2Name = parseDict[cls.dataIn2Name]
3531
outputBufferName = parseDict[cls.dataOutName]
3632

37-
input1Shape = list(ctxt.lookup(inputBuffer1Name).shape)
38-
input2Shape = list(ctxt.lookup(inputBuffer2Name).shape)
39-
outputShape = list(ctxt.lookup(outputBufferName).shape)
40-
41-
# Add all tensors to model
4233
tilerModel.addTensorDimToModel(ctxt, inputBuffer1Name)
4334
tilerModel.addTensorDimToModel(ctxt, inputBuffer2Name)
4435
tilerModel.addTensorDimToModel(ctxt, outputBufferName)
4536

46-
outNdim = len(outputShape)
47-
48-
# Pad input shapes from the left to match output ndim (ONNX broadcasting)
49-
padded1 = [1] * (outNdim - len(input1Shape)) + input1Shape
50-
padded2 = [1] * (outNdim - len(input2Shape)) + input2Shape
51-
52-
for outDim in range(outNdim):
53-
outputDimVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = outDim)
54-
55-
# Input 1: map output dim to actual tensor dim
56-
in1ActualDim = outDim - (outNdim - len(input1Shape))
57-
if in1ActualDim >= 0:
58-
in1DimVar = tilerModel.getTensorDimVar(tensorName = inputBuffer1Name, dimIdx = in1ActualDim)
59-
if padded1[outDim] == 1:
60-
tilerModel.addConstraint(in1DimVar == 1)
61-
else:
62-
tilerModel.addConstraint(in1DimVar == outputDimVar)
63-
64-
# Input 2: map output dim to actual tensor dim
65-
in2ActualDim = outDim - (outNdim - len(input2Shape))
66-
if in2ActualDim >= 0:
67-
in2DimVar = tilerModel.getTensorDimVar(tensorName = inputBuffer2Name, dimIdx = in2ActualDim)
68-
if padded2[outDim] == 1:
69-
tilerModel.addConstraint(in2DimVar == 1)
70-
else:
71-
tilerModel.addConstraint(in2DimVar == outputDimVar)
37+
input1Shape = list(ctxt.lookup(inputBuffer1Name).shape)
38+
input2Shape = list(ctxt.lookup(inputBuffer2Name).shape)
39+
40+
is_scalar = (np.prod(input2Shape) == 1)
41+
42+
if is_scalar:
43+
# Scalar: tile A and C together, B stays fixed
44+
for dim in range(len(input1Shape)):
45+
in1Var = tilerModel.getTensorDimVar(tensorName = inputBuffer1Name, dimIdx = dim)
46+
outVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = dim)
47+
tilerModel.addConstraint(in1Var == outVar)
48+
for dim in range(len(input2Shape)):
49+
in2Var = tilerModel.getTensorDimVar(tensorName = inputBuffer2Name, dimIdx = dim)
50+
tilerModel.addConstraint(in2Var == input2Shape[dim])
51+
else:
52+
# Element-wise: all three tensors tiled identically
53+
for dim in range(len(input1Shape)):
54+
in1Var = tilerModel.getTensorDimVar(tensorName = inputBuffer1Name, dimIdx = dim)
55+
in2Var = tilerModel.getTensorDimVar(tensorName = inputBuffer2Name, dimIdx = dim)
56+
outVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = dim)
57+
tilerModel.addConstraint(in1Var == in2Var)
58+
tilerModel.addConstraint(in1Var == outVar)
7259

7360
return tilerModel
7461

@@ -86,38 +73,19 @@ def serializeTilingSolution(
8673
replacements = {"size": []}
8774
replacementTypes = {"size": PointerClass(uint16_t)}
8875

89-
input1Shape = list(ctxt.lookup(operatorRepresentation[cls.dataIn1Name]).shape)
9076
input2Shape = list(ctxt.lookup(operatorRepresentation[cls.dataIn2Name]).shape)
91-
outputShape = list(ctxt.lookup(operatorRepresentation[cls.dataOutName]).shape)
92-
93-
outNdim = len(outputShape)
94-
padded1 = [1] * (outNdim - len(input1Shape)) + input1Shape
95-
padded2 = [1] * (outNdim - len(input2Shape)) + input2Shape
96-
97-
def _deriveInputCube(outputCube, inputShape, paddedShape):
98-
"""Derive an input HyperRectangle from an output cube, respecting broadcasting."""
99-
offset = []
100-
dims = []
101-
for outDim in range(outNdim):
102-
actualDim = outDim - (outNdim - len(inputShape))
103-
if actualDim >= 0:
104-
if paddedShape[outDim] == 1:
105-
offset.append(0)
106-
dims.append(1)
107-
else:
108-
offset.append(outputCube.offset[outDim])
109-
dims.append(outputCube.dims[outDim])
110-
return HyperRectangle(tuple(offset), tuple(dims))
77+
is_scalar = (np.prod(input2Shape) == 1)
11178

11279
inputLoadSchedule = []
11380
outputLoadSchedule = []
11481

11582
for cube in outputCubes:
11683
replacements["size"].append(np.prod(cube.dims))
117-
118-
in1Cube = _deriveInputCube(cube, input1Shape, padded1)
119-
in2Cube = _deriveInputCube(cube, input2Shape, padded2)
120-
inputLoadSchedule.append({cls.dataIn1Name: in1Cube, cls.dataIn2Name: in2Cube})
84+
if is_scalar:
85+
in2Cube = HyperRectangle(tuple([0] * len(input2Shape)), tuple(input2Shape))
86+
inputLoadSchedule.append({cls.dataIn1Name: cube, cls.dataIn2Name: in2Cube})
87+
else:
88+
inputLoadSchedule.append({cls.dataIn1Name: cube, cls.dataIn2Name: cube})
12189

12290
for out in outputCubes:
12391
outputLoadSchedule.append({cls.dataOutName: out})

0 commit comments

Comments
 (0)