1717
1818
1919class FloatDivTileConstraint (TileConstraint ):
20- """Tile constraint for FP32 Div operation with ONNX broadcasting support.
21-
22- Supports general NumPy-style broadcasting: both inputs can have any
23- dimension, including scalar, partial broadcasting, and full element-wise.
24- """
20+ """Tile constraint for FP32 Div: supports scalar and element-wise cases."""
2521
2622 dataIn1Name = "A"
2723 dataIn2Name = "B"
@@ -34,41 +30,32 @@ def addGeometricalConstraint(cls, tilerModel: TilerModel, parseDict: Dict, ctxt:
3430 inputBuffer2Name = parseDict [cls .dataIn2Name ]
3531 outputBufferName = parseDict [cls .dataOutName ]
3632
37- input1Shape = list (ctxt .lookup (inputBuffer1Name ).shape )
38- input2Shape = list (ctxt .lookup (inputBuffer2Name ).shape )
39- outputShape = list (ctxt .lookup (outputBufferName ).shape )
40-
41- # Add all tensors to model
4233 tilerModel .addTensorDimToModel (ctxt , inputBuffer1Name )
4334 tilerModel .addTensorDimToModel (ctxt , inputBuffer2Name )
4435 tilerModel .addTensorDimToModel (ctxt , outputBufferName )
4536
46- outNdim = len (outputShape )
47-
48- # Pad input shapes from the left to match output ndim (ONNX broadcasting)
49- padded1 = [1 ] * (outNdim - len (input1Shape )) + input1Shape
50- padded2 = [1 ] * (outNdim - len (input2Shape )) + input2Shape
51-
52- for outDim in range (outNdim ):
53- outputDimVar = tilerModel .getTensorDimVar (tensorName = outputBufferName , dimIdx = outDim )
54-
55- # Input 1: map output dim to actual tensor dim
56- in1ActualDim = outDim - (outNdim - len (input1Shape ))
57- if in1ActualDim >= 0 :
58- in1DimVar = tilerModel .getTensorDimVar (tensorName = inputBuffer1Name , dimIdx = in1ActualDim )
59- if padded1 [outDim ] == 1 :
60- tilerModel .addConstraint (in1DimVar == 1 )
61- else :
62- tilerModel .addConstraint (in1DimVar == outputDimVar )
63-
64- # Input 2: map output dim to actual tensor dim
65- in2ActualDim = outDim - (outNdim - len (input2Shape ))
66- if in2ActualDim >= 0 :
67- in2DimVar = tilerModel .getTensorDimVar (tensorName = inputBuffer2Name , dimIdx = in2ActualDim )
68- if padded2 [outDim ] == 1 :
69- tilerModel .addConstraint (in2DimVar == 1 )
70- else :
71- tilerModel .addConstraint (in2DimVar == outputDimVar )
37+ input1Shape = list (ctxt .lookup (inputBuffer1Name ).shape )
38+ input2Shape = list (ctxt .lookup (inputBuffer2Name ).shape )
39+
40+ is_scalar = (np .prod (input2Shape ) == 1 )
41+
42+ if is_scalar :
43+ # Scalar: tile A and C together, B stays fixed
44+ for dim in range (len (input1Shape )):
45+ in1Var = tilerModel .getTensorDimVar (tensorName = inputBuffer1Name , dimIdx = dim )
46+ outVar = tilerModel .getTensorDimVar (tensorName = outputBufferName , dimIdx = dim )
47+ tilerModel .addConstraint (in1Var == outVar )
48+ for dim in range (len (input2Shape )):
49+ in2Var = tilerModel .getTensorDimVar (tensorName = inputBuffer2Name , dimIdx = dim )
50+ tilerModel .addConstraint (in2Var == input2Shape [dim ])
51+ else :
52+ # Element-wise: all three tensors tiled identically
53+ for dim in range (len (input1Shape )):
54+ in1Var = tilerModel .getTensorDimVar (tensorName = inputBuffer1Name , dimIdx = dim )
55+ in2Var = tilerModel .getTensorDimVar (tensorName = inputBuffer2Name , dimIdx = dim )
56+ outVar = tilerModel .getTensorDimVar (tensorName = outputBufferName , dimIdx = dim )
57+ tilerModel .addConstraint (in1Var == in2Var )
58+ tilerModel .addConstraint (in1Var == outVar )
7259
7360 return tilerModel
7461
@@ -86,38 +73,19 @@ def serializeTilingSolution(
8673 replacements = {"size" : []}
8774 replacementTypes = {"size" : PointerClass (uint16_t )}
8875
89- input1Shape = list (ctxt .lookup (operatorRepresentation [cls .dataIn1Name ]).shape )
9076 input2Shape = list (ctxt .lookup (operatorRepresentation [cls .dataIn2Name ]).shape )
91- outputShape = list (ctxt .lookup (operatorRepresentation [cls .dataOutName ]).shape )
92-
93- outNdim = len (outputShape )
94- padded1 = [1 ] * (outNdim - len (input1Shape )) + input1Shape
95- padded2 = [1 ] * (outNdim - len (input2Shape )) + input2Shape
96-
97- def _deriveInputCube (outputCube , inputShape , paddedShape ):
98- """Derive an input HyperRectangle from an output cube, respecting broadcasting."""
99- offset = []
100- dims = []
101- for outDim in range (outNdim ):
102- actualDim = outDim - (outNdim - len (inputShape ))
103- if actualDim >= 0 :
104- if paddedShape [outDim ] == 1 :
105- offset .append (0 )
106- dims .append (1 )
107- else :
108- offset .append (outputCube .offset [outDim ])
109- dims .append (outputCube .dims [outDim ])
110- return HyperRectangle (tuple (offset ), tuple (dims ))
77+ is_scalar = (np .prod (input2Shape ) == 1 )
11178
11279 inputLoadSchedule = []
11380 outputLoadSchedule = []
11481
11582 for cube in outputCubes :
11683 replacements ["size" ].append (np .prod (cube .dims ))
117-
118- in1Cube = _deriveInputCube (cube , input1Shape , padded1 )
119- in2Cube = _deriveInputCube (cube , input2Shape , padded2 )
120- inputLoadSchedule .append ({cls .dataIn1Name : in1Cube , cls .dataIn2Name : in2Cube })
84+ if is_scalar :
85+ in2Cube = HyperRectangle (tuple ([0 ] * len (input2Shape )), tuple (input2Shape ))
86+ inputLoadSchedule .append ({cls .dataIn1Name : cube , cls .dataIn2Name : in2Cube })
87+ else :
88+ inputLoadSchedule .append ({cls .dataIn1Name : cube , cls .dataIn2Name : cube })
12189
12290 for out in outputCubes :
12391 outputLoadSchedule .append ({cls .dataOutName : out })
0 commit comments