@@ -193,10 +193,12 @@ class FloatGEMMTileConstraint(TileConstraint):
193193 @staticmethod
194194 def addGeometricalConstraint (tilerModel : TilerModel , parseDict : Dict , ctxt : NetworkContext ) -> TilerModel :
195195
196+ # Get to-be-tiled tensor's buffers
196197 bufferA = ctxt .lookup (name = parseDict ['A' ])
197198 bufferB = ctxt .lookup (name = parseDict ['B' ])
198199 outputBuffer = ctxt .lookup (name = parseDict ['data_out' ])
199200
201+ # Add I/O dimensions to the model as variables
200202 has_bias = 'C' in parseDict and parseDict ['C' ] is not None
201203 bufferC = None
202204 if has_bias :
@@ -222,9 +224,11 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
222224 outputFirstDimVar = tilerModel .getTensorDimVar (tensorName = outputBuffer .name , dimIdx = dimOffsetOut )
223225 outputSecondDimVar = tilerModel .getTensorDimVar (tensorName = outputBuffer .name , dimIdx = dimOffsetOut + 1 )
224226
227+ # Map output dims to inputs dims
225228 tilerModel .addConstraint (outputFirstDimVar == AFirstDimVar )
226229 tilerModel .addConstraint (outputSecondDimVar == BSecondDimVar )
227230
231+ # Add GEMM Geometrical constraints
228232 tilerModel .addConstraint (ASecondDimVar == BFirstDimVar )
229233
230234 # Add bias constraints only if bias is present
@@ -287,7 +291,6 @@ def serializeTilingSolution(
287291 transB = operatorRepresentation ['transB' ]
288292
289293 varA = operatorRepresentation ['A' ]
290- varB = operatorRepresentation ['B' ]
291294
292295 if transA == 0 :
293296 NSize = ctxt .lookup (varA ).shape [- 1 ]
0 commit comments