@@ -61,6 +61,33 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
6161
6262 return tilerModel
6363
64+ @staticmethod
65+ def addPolicyConstraint (tilerModel : TilerModel , parseDict : Dict , ctxt : NetworkContext ) -> TilerModel :
66+ # ===== GET NECESSARY INFORMATION =====
67+ # Get I/O buffer names
68+ inputBufferName = parseDict ['data_in' ]
69+
70+ # Get other necessary information
71+ inputShape = parseDict ['data_in_shape' ]
72+ reduceAxes = parseDict ['axes' ]
73+ nonReducedDims = [ax for ax in range (len (inputShape )) if ax not in reduceAxes ]
74+
75+ if len (nonReducedDims ) > 0 :
76+ biggestNonReducedDim = max (nonReducedDims , key = lambda ax : inputShape [ax ])
77+ else :
78+ biggestNonReducedDim = - 1 # No non-reduced dimensions
79+
80+ # ===== ADD CONSTRAINTS =====
81+ # Kernel parallelized only on biggest non-reduced dimension,
82+ # so tile only on that dimension
83+ for ax in range (len (inputShape )):
84+ dimVar = tilerModel .getTensorDimVar (tensorName = inputBufferName , dimIdx = ax )
85+ if ax != biggestNonReducedDim :
86+ # This is not the biggest non-reduced dimension, force no tiling
87+ tilerModel .addConstraint (dimVar == inputShape [ax ])
88+
89+ return tilerModel
90+
6491 @staticmethod
6592 def constructSymbolicNodeRep (tilerModel : TilerModel , parseDict : Dict ,
6693 ctxt : NetworkContext ) -> Dict [str , Union [int , IntVar ]]:
0 commit comments