Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3498,12 +3498,16 @@ def noLoadLoop( self, kernel, tensorParametersA, tensorParametersB, isOptNLL, is
module.add(self._wait(kernel, tensorParametersA, tensorParametersB, vlcntVal, -1, -1, "10wait for global read"))
if not kernel["NoLdsWriteCode"]:
module.add(self._wait(kernel, tensorParametersA, tensorParametersB, -1, 0, -1, "4wait for local write"))
module.add(self._syncThreads(kernel, "Wait GR->LW done, sync LDS%u"%self.states.ldsWriteTokenIdx, memoryToken=[self.states.ldsWriteTokenIdx]))

if kernel["enableTDMA"] and kernel["enableTDMB"]:
module.add(self._syncThreads(kernel, "wait for local write done, sync LDS%u"%self.states.ldsBarrierTokenIdx, memoryToken=[self.states.ldsBarrierTokenIdx]))
# swap barrier token, locked by OptNll
if not isOptNLL:
self.states.ldsBarrierTokenIdx = self.states.memTokenLdsBuffer1 if self.states.ldsBarrierTokenIdx == self.states.memTokenLdsBuffer0 else self.states.memTokenLdsBuffer0
elif kernel["enableTDMA"] and kernel["enableTDMB"]:
module.add(self._wait(kernel, tensorParametersA, tensorParametersB, 0, -1, -1, "wait for tensor load to finish"))
module.add(self._syncThreads(kernel))

module.add(self._syncThreads(kernel, "wait for tensor load done, sync LDS%u"%self.states.ldsBarrierTokenIdx, memoryToken=[self.states.ldsBarrierTokenIdx]))
# swap barrier token
if not isOptNLL:
self.states.ldsBarrierTokenIdx = self.states.memTokenLdsBuffer1 if self.states.ldsBarrierTokenIdx == self.states.memTokenLdsBuffer0 else self.states.memTokenLdsBuffer0
# generate no Load Loop Body code
module.add(self.noLoadLoopBody(kernel, tensorParametersA, tensorParametersB, pack, packPre, isOptNLL, isNGLL, NLLfirst, NLLlast, NLLindex=NLLindex, \
NLLnum=NLLnum, useTailloopInNll=useTailloopInNll, remainPgr=remainPgr))
Expand Down Expand Up @@ -3540,13 +3544,14 @@ def _loopBody( self, kernel, tensorParametersA, tensorParametersB, pack, packPre
module.add(self._wait(kernel, tensorParametersA, tensorParametersB, vlcntVal, -1, -1, "11wait for global read"))
if not kernel["NoLdsWriteCode"]:
module.add(self._wait(kernel, tensorParametersA, tensorParametersB, 1, 0, -1, "1wait for local write"))
module.add(self._syncThreads(kernel, "4sync for global read, PGR->LW needs sync LDS0", memoryToken=[self.states.ldsBarrierTokenIdx]))
module.add(self._syncThreads(kernel, "4sync for global read, PGR->LW needs sync LDS%u"%(self.states.ldsBarrierTokenIdx), memoryToken=[self.states.ldsBarrierTokenIdx]))
# swap barrier token
self.states.ldsBarrierTokenIdx = self.states.memTokenLdsBuffer1 if self.states.ldsBarrierTokenIdx == self.states.memTokenLdsBuffer0 else self.states.memTokenLdsBuffer0

if kernel["PrefetchGlobalRead"] and kernel["enableTDMA"] and kernel["enableTDMB"]:
elif kernel["PrefetchGlobalRead"] and kernel["enableTDMA"] and kernel["enableTDMB"]:
module.add(self._wait(kernel, tensorParametersA, tensorParametersB, 0, -1, -1, "wait for tensor load to finish"))
module.add(self._syncThreads(kernel))
module.add(self._syncThreads(kernel, "wait for tensor load to finish, PGR->LW needs sync LDS%u"%(self.states.ldsBarrierTokenIdx), memoryToken=[self.states.ldsBarrierTokenIdx]))
# swap barrier token
self.states.ldsBarrierTokenIdx = self.states.memTokenLdsBuffer1 if self.states.ldsBarrierTokenIdx == self.states.memTokenLdsBuffer0 else self.states.memTokenLdsBuffer0

module.addComment1("Begin Each Unroll: Check VGPR.checkin for INT8 LW")

Expand Down Expand Up @@ -4983,6 +4988,10 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
# loop body code generation
finalLoop = lc == loopCopies - 1
loop.add(self._loopBody( kernel, tensorParametersA, tensorParametersB, pack, packPre, lc, loopCopies, finalLoop, isDTVGRSecondBuf=isDTVGRSecondBuf ))
if self.states.numItersPLR == 0 and not finalLoop:
# swap LDS read buffer
self.states.ldsReadTokenIdx = self.states.memTokenLdsBuffer1 if self.states.ldsReadTokenIdx == self.states.memTokenLdsBuffer0 else self.states.memTokenLdsBuffer0

module.add(loop)

if kernel["ExpertSchedulingMode"] > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6993,7 +6993,7 @@ def closeLoop(self, kernel, tPA, tPB, loopIdx, finalLoop, emitEndLabelOnly=False

if kernel["enableTDMA"] and kernel["enableTDMB"] and not kernel["PrefetchGlobalRead"]:
module.add(SWaitCnt(dscnt=0, comment="TDM PGR=0: wait all ds_reads before TDM overwrite"))
module.add(SBarrier(comment="TDM PGR=0: signal+wait done reading LDS"))
module.add(SBarrier(comment="TDM PGR=0: signal+wait done reading LDS", memoryToken=[self.states.memTokenLdsBuffer0]))

# If PrefetchGlobalRead=1 the loads in the loop prefetch next macro-tile
# For the final trip through the unroll loop we need to ensure those loads stay in bounds.
Expand Down
Loading
Loading