Skip to content

Commit 05ac839

Browse files
committed
Fix a missed merge conflict
Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent e53cba5 commit 05ac839

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/cuda/tile/_ir/ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2866,9 +2866,9 @@ def mma_scaled_impl(x: Var, x_scale: Var, y: Var, y_scale: Var, acc: Var) -> Var
28662866
raise TileTypeError(f'Expect acc shape to be {output_shape}, got {acc_ty.shape}')
28672867

28682868
# Broadcast scale batch dims to match the broadcasted x/y batch dims
2869-
batch = x_shape.value_types[:-2]
2870-
x_scale_shape = TupleTy(batch + x_scale_ty.shape.value_types[-2:])
2871-
y_scale_shape = TupleTy(batch + y_scale_ty.shape.value_types[-2:])
2869+
batch = x_shape[:-2]
2870+
x_scale_shape = TupleTy(batch + x_scale_ty.shape[-2:])
2871+
y_scale_shape = TupleTy(batch + y_scale_ty.shape[-2:])
28722872

28732873
x = _promote_and_broadcast_to(x, TileTy(x_ty.dtype, x_shape))
28742874
y = _promote_and_broadcast_to(y, TileTy(y_ty.dtype, y_shape))

0 commit comments

Comments
 (0)