@@ -226,16 +226,23 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
226226 # Get to-be-tiled tensor's buffers
227227 bufferA = ctxt .lookup (name = parseDict ['A' ])
228228 bufferB = ctxt .lookup (name = parseDict ['B' ])
229- bufferC = ctxt .lookup (name = parseDict ['C' ])
230229 outputBuffer = ctxt .lookup (name = parseDict ['data_out' ])
230+
231+ has_bias = 'C' in parseDict and parseDict ['C' ] is not None
232+ bufferC = None
233+ if has_bias :
234+ bufferC = ctxt .lookup (name = parseDict ['C' ])
231235
232236 # Add I/O dimensions to the model as variables
233- for bufferName in [bufferA .name , bufferB .name , bufferC .name , outputBuffer .name ]:
237+ buffer_names = [bufferA .name , bufferB .name , outputBuffer .name ]
238+ if has_bias :
239+ buffer_names .append (bufferC .name )
240+
241+ for bufferName in buffer_names :
234242 tilerModel .addTensorDimToModel (ctxt , bufferName )
235243
236244 dimOffsetA = len (bufferA .shape ) - 2
237245 dimOffsetB = len (bufferB .shape ) - 2
238- dimOffsetC = len (bufferC .shape ) - 2
239246 dimOffsetOut = len (outputBuffer .shape ) - 2
240247
241248 AFirstDimVar = tilerModel .getTensorDimVar (tensorName = bufferA .name , dimIdx = dimOffsetA + parseDict ['transA' ])
@@ -254,10 +261,13 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
254261 # Add GEMM Geometrical constraints
255262 tilerModel .addConstraint (ASecondDimVar == BFirstDimVar )
256263
257- addDimVar_1 = tilerModel .getTensorDimVar (tensorName = bufferC .name , dimIdx = dimOffsetC )
258- addDimVar_2 = tilerModel .getTensorDimVar (tensorName = bufferC .name , dimIdx = dimOffsetC + 1 )
259- tilerModel .addConstraint (outputFirstDimVar == addDimVar_1 )
260- tilerModel .addConstraint (outputSecondDimVar == addDimVar_2 )
264+ # Add bias constraints only if bias is present
265+ if has_bias :
266+ dimOffsetC = len (bufferC .shape ) - 2
267+ addDimVar_1 = tilerModel .getTensorDimVar (tensorName = bufferC .name , dimIdx = dimOffsetC )
268+ addDimVar_2 = tilerModel .getTensorDimVar (tensorName = bufferC .name , dimIdx = dimOffsetC + 1 )
269+ tilerModel .addConstraint (outputFirstDimVar == addDimVar_1 )
270+ tilerModel .addConstraint (outputSecondDimVar == addDimVar_2 )
261271
262272 return tilerModel
263273
@@ -295,7 +305,14 @@ def serializeTilingSolution(
295305 operatorRepresentation : OperatorRepresentation ) -> Tuple [VariableReplacementScheme , TilingSchedule ]:
296306 outputCubes = [cube .rectangle for cube in absoluteOutputCubes ]
297307
298- addrNames = ['A' , 'B' , 'C' , 'data_out' ]
308+ # Check if C (bias) is present
309+ has_bias = 'C' in operatorRepresentation and operatorRepresentation ['C' ] is not None
310+
311+ # Build address names list based on whether bias is present
312+ addrNames = ['A' , 'B' , 'data_out' ]
313+ if has_bias :
314+ addrNames .insert (2 , 'C' ) # Insert 'C' before 'data_out'
315+
299316 inputBaseOffsets , outputBaseOffsets = cls .extractBaseAddr (tilingSolution , targetMemLevel ,
300317 operatorRepresentation , addrNames )
301318
@@ -350,11 +367,13 @@ def serializeTilingSolution(
350367 else :
351368 BCube = HyperRectangle ((BatchOffset , BOffset , OOffset , NOffset ), (BatchSize , BSize , OSize , NSize ))
352369
353- CCube = HyperRectangle (cube .offset , cube .dims )
354-
355370 inputACubes .append (ACube )
356371 inputBCubes .append (BCube )
357- inputAddCubes .append (CCube )
372+
373+ # Only create C cubes if bias is present
374+ if has_bias :
375+ CCube = HyperRectangle (cube .offset , cube .dims )
376+ inputAddCubes .append (CCube )
358377
359378 inputLoadSchedule = []
360379 outputLoadSchedule = []
@@ -368,8 +387,13 @@ def serializeTilingSolution(
368387 "batch" : PointerClass (uint8_t )
369388 }
370389
371- for a , b , c in zip (inputACubes , inputBCubes , inputAddCubes ):
372- inputLoadSchedule .append ({"A" : a , "B" : b , "C" : c })
390+ # Build input load schedule based on whether bias is present
391+ if has_bias :
392+ for a , b , c in zip (inputACubes , inputBCubes , inputAddCubes ):
393+ inputLoadSchedule .append ({"A" : a , "B" : b , "C" : c })
394+ else :
395+ for a , b in zip (inputACubes , inputBCubes ):
396+ inputLoadSchedule .append ({"A" : a , "B" : b })
373397
374398 for out in outputCubes :
375399 outputLoadSchedule .append ({"data_out" : out })
0 commit comments