Skip to content

Commit 52d30d6

Browse files
authored
webgpu: Fix wgsl failure from MatMulSplitKProgram (#6838)
Fix that Tint WGSL reader failure: index 3 out of bounds [0..2] from MatMulSplitKProgram.
1 parent 9317f2e commit 52d30d6

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

tfjs-backend-webgpu/src/webgpu_program.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,13 @@ function getOutputCoordsSnippet(
629629
const {x, y = [], z = []} = dispatchLayout;
630630

631631
const outRank = outShape.length;
632+
const rank = x.length + y.length + z.length;
633+
// getOutputCoords is only meaningful when the output rank is same with
634+
// dispatch layout rank.
635+
if (rank !== outRank) {
636+
return '';
637+
}
638+
632639
if (x.length === outRank) {
633640
const dtype = getCoordsDataType(outRank);
634641
const snippet = `fn getOutputCoords() -> ${dtype}{
@@ -642,17 +649,13 @@ function getOutputCoordsSnippet(
642649
let gatherDimensionsStr = '';
643650
const dims = [x, y, z];
644651

645-
let rank = 0;
646-
647652
for (let i = 0; i < dims.length; i++) {
648653
const arr = dims[i];
649654

650655
if (arr.length === 0) {
651656
continue;
652657
}
653658

654-
rank += arr.length;
655-
656659
if (arr.length === 1) {
657660
gatherDimensionsStr += `let d${arr[0]} = i32(globalId[${i}]);`;
658661
} else {

0 commit comments

Comments
 (0)