@@ -3498,12 +3498,16 @@ def noLoadLoop( self, kernel, tensorParametersA, tensorParametersB, isOptNLL, is
34983498 module .add (self ._wait (kernel , tensorParametersA , tensorParametersB , vlcntVal , - 1 , - 1 , "10wait for global read" ))
34993499 if not kernel ["NoLdsWriteCode" ]:
35003500 module .add (self ._wait (kernel , tensorParametersA , tensorParametersB , - 1 , 0 , - 1 , "4wait for local write" ))
3501- module .add (self ._syncThreads (kernel , "Wait GR->LW done, sync LDS%u" % self .states .ldsWriteTokenIdx , memoryToken = [self .states .ldsWriteTokenIdx ]))
3502-
3503- if kernel ["enableTDMA" ] and kernel ["enableTDMB" ]:
3501+ module .add (self ._syncThreads (kernel , "wait for local write done, sync LDS%u" % self .states .ldsBarrierTokenIdx , memoryToken = [self .states .ldsBarrierTokenIdx ]))
3502+ # swap barrier token, locked by OptNll
3503+ if not isOptNLL :
3504+ self .states .ldsBarrierTokenIdx = self .states .memTokenLdsBuffer1 if self .states .ldsBarrierTokenIdx == self .states .memTokenLdsBuffer0 else self .states .memTokenLdsBuffer0
3505+ elif kernel ["enableTDMA" ] and kernel ["enableTDMB" ]:
35043506 module .add (self ._wait (kernel , tensorParametersA , tensorParametersB , 0 , - 1 , - 1 , "wait for tensor load to finish" ))
3505- module .add (self ._syncThreads (kernel ))
3506-
3507+ module .add (self ._syncThreads (kernel , "wait for tensor load done, sync LDS%u" % self .states .ldsBarrierTokenIdx , memoryToken = [self .states .ldsBarrierTokenIdx ]))
3508+ # swap barrier token
3509+ if not isOptNLL :
3510+ self .states .ldsBarrierTokenIdx = self .states .memTokenLdsBuffer1 if self .states .ldsBarrierTokenIdx == self .states .memTokenLdsBuffer0 else self .states .memTokenLdsBuffer0
35073511 # generate no Load Loop Body code
35083512 module .add (self .noLoadLoopBody (kernel , tensorParametersA , tensorParametersB , pack , packPre , isOptNLL , isNGLL , NLLfirst , NLLlast , NLLindex = NLLindex , \
35093513 NLLnum = NLLnum , useTailloopInNll = useTailloopInNll , remainPgr = remainPgr ))
@@ -3540,13 +3544,14 @@ def _loopBody( self, kernel, tensorParametersA, tensorParametersB, pack, packPre
35403544 module .add (self ._wait (kernel , tensorParametersA , tensorParametersB , vlcntVal , - 1 , - 1 , "11wait for global read" ))
35413545 if not kernel ["NoLdsWriteCode" ]:
35423546 module .add (self ._wait (kernel , tensorParametersA , tensorParametersB , 1 , 0 , - 1 , "1wait for local write" ))
3543- module .add (self ._syncThreads (kernel , "4sync for global read, PGR->LW needs sync LDS0" , memoryToken = [self .states .ldsBarrierTokenIdx ]))
3547+ module .add (self ._syncThreads (kernel , "4sync for global read, PGR->LW needs sync LDS%u" % ( self . states . ldsBarrierTokenIdx ) , memoryToken = [self .states .ldsBarrierTokenIdx ]))
35443548 # swap barrier token
35453549 self .states .ldsBarrierTokenIdx = self .states .memTokenLdsBuffer1 if self .states .ldsBarrierTokenIdx == self .states .memTokenLdsBuffer0 else self .states .memTokenLdsBuffer0
3546-
3547- if kernel ["PrefetchGlobalRead" ] and kernel ["enableTDMA" ] and kernel ["enableTDMB" ]:
3550+ elif kernel ["PrefetchGlobalRead" ] and kernel ["enableTDMA" ] and kernel ["enableTDMB" ]:
35483551 module .add (self ._wait (kernel , tensorParametersA , tensorParametersB , 0 , - 1 , - 1 , "wait for tensor load to finish" ))
3549- module .add (self ._syncThreads (kernel ))
3552+ module .add (self ._syncThreads (kernel , "wait for tensor load to finish, PGR->LW needs sync LDS%u" % (self .states .ldsBarrierTokenIdx ), memoryToken = [self .states .ldsBarrierTokenIdx ]))
3553+ # swap barrier token
3554+ self .states .ldsBarrierTokenIdx = self .states .memTokenLdsBuffer1 if self .states .ldsBarrierTokenIdx == self .states .memTokenLdsBuffer0 else self .states .memTokenLdsBuffer0
35503555
35513556 module .addComment1 ("Begin Each Unroll: Check VGPR.checkin for INT8 LW" )
35523557
@@ -4983,6 +4988,10 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
49834988 # loop body code generation
49844989 finalLoop = lc == loopCopies - 1
49854990 loop .add (self ._loopBody ( kernel , tensorParametersA , tensorParametersB , pack , packPre , lc , loopCopies , finalLoop , isDTVGRSecondBuf = isDTVGRSecondBuf ))
4991+ if self .states .numItersPLR == 0 and not finalLoop :
4992+ # swap LDS read buffer
4993+ self .states .ldsReadTokenIdx = self .states .memTokenLdsBuffer1 if self .states .ldsReadTokenIdx == self .states .memTokenLdsBuffer0 else self .states .memTokenLdsBuffer0
4994+
49864995 module .add (loop )
49874996
49884997 if kernel ["ExpertSchedulingMode" ] > 0 :
0 commit comments