Skip to content

Commit dfe59a8

Browse files
committed
fix
1 parent ab2ce10 commit dfe59a8

3 files changed

Lines changed: 83 additions & 53 deletions

File tree

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -948,25 +948,7 @@ fn trans_intrinsic_type<'tcx>(
948948
}
949949

950950
// Generic arg 0: component type T
951-
let component_type = match args.type_at(0).kind() {
952-
TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx),
953-
TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx),
954-
TyKind::Int(IntTy::I8) => SpirvType::Integer(8, true).def(span, cx),
955-
TyKind::Int(IntTy::I16) => SpirvType::Integer(16, true).def(span, cx),
956-
TyKind::Int(IntTy::I32) => SpirvType::Integer(32, true).def(span, cx),
957-
TyKind::Int(IntTy::I64) => SpirvType::Integer(64, true).def(span, cx),
958-
TyKind::Uint(UintTy::U8) => SpirvType::Integer(8, false).def(span, cx),
959-
TyKind::Uint(UintTy::U16) => SpirvType::Integer(16, false).def(span, cx),
960-
TyKind::Uint(UintTy::U32) => SpirvType::Integer(32, false).def(span, cx),
961-
TyKind::Uint(UintTy::U64) => SpirvType::Integer(64, false).def(span, cx),
962-
_ => {
963-
return Err(cx.tcx.dcx().span_err(
964-
span,
965-
"unsupported component type for #[spirv(cooperative_matrix)]: \
966-
must be f32, f64, i8, i16, i32, i64, u8, u16, u32, or u64",
967-
));
968-
}
969-
};
951+
let component_type = cx.layout_of(args.type_at(0)).spirv_type(span, cx);
970952

971953
// Const generic 1: USE (MatrixA=0, MatrixB=1, MatrixAccumulator=2)
972954
// Const generic 2: ROWS

crates/spirv-std/src/cooperative_matrix.rs

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,47 @@
55
//! -C target-feature=+CooperativeMatrixKHR,+ext:SPV_KHR_cooperative_matrix
66
//! ```
77
//!
8+
//! See the [SPV_KHR_cooperative_matrix specification] for full details.
9+
//!
10+
//! [SPV_KHR_cooperative_matrix specification]: https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html
811
#[cfg(target_arch = "spirv")]
912
use core::arch::asm;
1013
use core::marker::PhantomData;
1114
use core::mem::MaybeUninit;
1215

16+
/// Matrix role in a cooperative multiply-accumulate operation (`D = A × B + C`).
17+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
18+
#[repr(u32)]
19+
pub enum MatrixUse {
20+
/// Input operand A.
21+
MatrixA = 0,
22+
/// Input operand B.
23+
MatrixB = 1,
24+
/// Accumulator / result.
25+
MatrixAccumulator = 2,
26+
}
27+
28+
/// Memory layout for cooperative matrix load/store operations.
29+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
30+
#[repr(u32)]
31+
pub enum MatrixLayout {
32+
/// Rows are stored contiguously.
33+
RowMajor = 0,
34+
/// Columns are stored contiguously.
35+
ColumnMajor = 1,
36+
}
37+
1338
/// Matrix role: input operand A in D = A × B + C.
14-
pub const MATRIX_A: u32 = 0;
39+
pub const MATRIX_A: u32 = MatrixUse::MatrixA as u32;
1540
/// Matrix role: input operand B in D = A × B + C.
16-
pub const MATRIX_B: u32 = 1;
41+
pub const MATRIX_B: u32 = MatrixUse::MatrixB as u32;
1742
/// Matrix role: accumulator / result in D = A × B + C.
18-
pub const MATRIX_ACCUMULATOR: u32 = 2;
43+
pub const MATRIX_ACCUMULATOR: u32 = MatrixUse::MatrixAccumulator as u32;
1944

2045
/// Memory layout: rows are stored contiguously.
21-
pub const ROW_MAJOR: u32 = 0;
46+
pub const ROW_MAJOR: MatrixLayout = MatrixLayout::RowMajor;
2247
/// Memory layout: columns are stored contiguously.
23-
pub const COLUMN_MAJOR: u32 = 1;
48+
pub const COLUMN_MAJOR: MatrixLayout = MatrixLayout::ColumnMajor;
2449

2550
/// A cooperative matrix distributed across the subgroup.
2651
///
@@ -29,7 +54,7 @@ pub const COLUMN_MAJOR: u32 = 1;
2954
///
3055
/// # Type parameters
3156
/// - `T`: element type (`f32`, `f64`, `i32`, `u32`, `i8`, `u8`, etc.)
32-
/// - `USE`: matrix role — one of [`MATRIX_A`], [`MATRIX_B`], [`MATRIX_ACCUMULATOR`]
57+
/// - `USE`: matrix role — one of [`MatrixUse::MatrixA`], [`MatrixUse::MatrixB`], [`MatrixUse::MatrixAccumulator`] cast to `u32` (equivalently [`MATRIX_A`], [`MATRIX_B`], [`MATRIX_ACCUMULATOR`])
3358
/// - `ROWS`: number of rows
3459
/// - `COLS`: number of columns
3560
///
@@ -46,22 +71,28 @@ pub struct CooperativeMatrix<T, const USE: u32, const ROWS: u32, const COLS: u32
4671
}
4772

4873
impl<T, const USE: u32, const ROWS: u32, const COLS: u32> CooperativeMatrix<T, USE, ROWS, COLS> {
49-
/// Load a cooperative matrix tile from `ptr` using `layout` and `stride`.
74+
/// Load a cooperative matrix from the memory backing `slice`.
5075
///
51-
/// - `ptr`: pointer to the first element of the tile in memory
52-
/// - `layout`: [`ROW_MAJOR`] or [`COLUMN_MAJOR`]
53-
/// - `stride`: distance (in elements) between the start of consecutive
54-
/// rows (row-major) or columns (column-major) in the full matrix
76+
/// Elements are read starting at `slice[0]`. `layout` specifies whether the matrix
77+
/// is stored in row-major ([`MatrixLayout::RowMajor`]) or column-major
78+
/// ([`MatrixLayout::ColumnMajor`]) order. `stride` is the number of elements
79+
/// between the start of consecutive rows (row-major) or columns (column-major).
80+
///
81+
/// The scope is always `Subgroup`.
5582
///
5683
/// # Safety
57-
/// `ptr` must be valid for `ROWS * stride` (row-major) or `COLS * stride`
58-
/// (column-major) element reads within the subgroup's access pattern.
84+
/// - `slice` must be long enough to cover all element accesses
85+
/// implied by the matrix dimensions, layout, and stride.
86+
/// - All operands must be dynamically uniform within every instance of the
87+
/// subgroup scope.
5988
#[spirv_std_macros::gpu_only]
6089
#[doc(alias = "OpCooperativeMatrixLoadKHR")]
6190
#[inline]
62-
pub unsafe fn load(ptr: *const T, layout: u32, stride: u32) -> Self {
91+
pub unsafe fn load(slice: &[T], layout: MatrixLayout, stride: u32) -> Self {
6392
unsafe {
6493
let mut result = MaybeUninit::<Self>::uninit();
94+
let layout_u32 = layout as u32;
95+
let ptr = slice.as_ptr();
6596
asm!(
6697
"%u32 = OpTypeInt 32 0",
6798
"%layout = OpLoad %u32 {layout}",
@@ -70,24 +101,35 @@ impl<T, const USE: u32, const ROWS: u32, const COLS: u32> CooperativeMatrix<T, U
70101
"%result = OpCooperativeMatrixLoadKHR typeof* {out} {ptr} %layout %stride",
71102
"OpStore {out} %result",
72103
ptr = in(reg) ptr,
73-
layout = in(reg) &layout,
104+
layout = in(reg) &layout_u32,
74105
stride = in(reg) &stride,
75106
out = in(reg) result.as_mut_ptr(),
76107
);
77108
result.assume_init()
78109
}
79110
}
80111

81-
/// Store this cooperative matrix tile to `ptr` using `layout` and `stride`.
112+
/// Store a cooperative matrix to the memory backing `slice`.
113+
///
114+
/// Elements are written starting at `slice[0]`. `layout` specifies whether the matrix
115+
/// is stored in row-major ([`MatrixLayout::RowMajor`]) or column-major
116+
/// ([`MatrixLayout::ColumnMajor`]) order. `stride` is the number of elements
117+
/// between the start of consecutive rows (row-major) or columns (column-major).
118+
///
119+
/// The scope is always `Subgroup`.
82120
///
83121
/// # Safety
84-
/// `ptr` must be valid for `ROWS * stride` (row-major) or `COLS * stride`
85-
/// (column-major) element writes within the subgroup's access pattern.
122+
/// - `slice` must be long enough to cover all element accesses
123+
/// implied by the matrix dimensions, layout, and stride.
124+
/// - All operands must be dynamically uniform within every instance of the
125+
/// subgroup scope.
86126
#[spirv_std_macros::gpu_only]
87127
#[doc(alias = "OpCooperativeMatrixStoreKHR")]
88128
#[inline]
89-
pub unsafe fn store(self, ptr: *mut T, layout: u32, stride: u32) {
129+
pub unsafe fn store(self, slice: &mut [T], layout: MatrixLayout, stride: u32) {
90130
unsafe {
131+
let layout_u32 = layout as u32;
132+
let ptr = slice.as_mut_ptr();
91133
asm!(
92134
"%u32 = OpTypeInt 32 0",
93135
"%layout = OpLoad %u32 {layout}",
@@ -96,7 +138,7 @@ impl<T, const USE: u32, const ROWS: u32, const COLS: u32> CooperativeMatrix<T, U
96138
"OpCooperativeMatrixStoreKHR {ptr} %matrix %layout %stride",
97139
ptr = in(reg) ptr,
98140
matrix = in(reg) &self,
99-
layout = in(reg) &layout,
141+
layout = in(reg) &layout_u32,
100142
stride = in(reg) &stride,
101143
);
102144
}
@@ -125,26 +167,32 @@ impl<T, const USE: u32, const ROWS: u32, const COLS: u32> CooperativeMatrix<T, U
125167
}
126168
}
127169

128-
/// Fused multiply-accumulate: `D = A × B + C`.
170+
/// Linear-algebraic matrix multiply of `A` by `B`, followed by a component-wise add of `C`.
171+
///
172+
/// The order of operations is implementation-dependent. All matrices must have the
173+
/// same scope, which is always subgroup here.
129174
///
130-
/// - `A`: `M × K` matrix with role [`MATRIX_A`]
131-
/// - `B`: `K × N` matrix with role [`MATRIX_B`]
132-
/// - `C`: `M × N` matrix with role [`MATRIX_ACCUMULATOR`]
133-
/// - returns: `M × N` accumulator equal to `A × B + C`
175+
/// - `A`: `M × K` matrix with use [`MatrixUse::MatrixA`]
176+
/// - `B`: `K × N` matrix with use [`MatrixUse::MatrixB`]
177+
/// - `C`: `M × N` matrix with use [`MatrixUse::MatrixAccumulator`]
178+
/// - returns `D`: `M × N` accumulator equal to `A × B + C`
134179
///
180+
/// All operands must be dynamically uniform within every instance of the subgroup scope.
135181
///
136182
/// # Capability
137183
/// Requires `CooperativeMatrixKHR` + `SPV_KHR_cooperative_matrix`.
138184
#[spirv_std_macros::gpu_only]
139185
#[doc(alias = "OpCooperativeMatrixMulAddKHR")]
140186
#[inline]
141187
pub fn mul_add<TA, TB, TC, const M: u32, const N: u32, const K: u32>(
142-
a: CooperativeMatrix<TA, MATRIX_A, M, K>,
143-
b: CooperativeMatrix<TB, MATRIX_B, K, N>,
144-
c: CooperativeMatrix<TC, MATRIX_ACCUMULATOR, M, N>,
145-
) -> CooperativeMatrix<TC, MATRIX_ACCUMULATOR, M, N> {
188+
a: CooperativeMatrix<TA, { MatrixUse::MatrixA as u32 }, M, K>,
189+
b: CooperativeMatrix<TB, { MatrixUse::MatrixB as u32 }, K, N>,
190+
c: CooperativeMatrix<TC, { MatrixUse::MatrixAccumulator as u32 }, M, N>,
191+
) -> CooperativeMatrix<TC, { MatrixUse::MatrixAccumulator as u32 }, M, N> {
146192
unsafe {
147-
let mut result = MaybeUninit::<CooperativeMatrix<TC, MATRIX_ACCUMULATOR, M, N>>::uninit();
193+
let mut result = MaybeUninit::<
194+
CooperativeMatrix<TC, { MatrixUse::MatrixAccumulator as u32 }, M, N>,
195+
>::uninit();
148196
asm!(
149197
"%a = OpLoad _ {a}",
150198
"%b = OpLoad _ {b}",

tests/compiletests/ui/arch/cooperative_matrix.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ pub fn main(
1717
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] b: &[f32],
1818
#[spirv(storage_buffer, descriptor_set = 0, binding = 2)] c: &mut [f32],
1919
) {
20-
let mat_a = unsafe { MatA::load(a.as_ptr(), ROW_MAJOR, 16) };
21-
let mat_b = unsafe { MatB::load(b.as_ptr(), COLUMN_MAJOR, 16) };
22-
let mat_c = unsafe { MatAcc::load(c.as_ptr(), ROW_MAJOR, 16) };
20+
let mat_a = unsafe { MatA::load(a, ROW_MAJOR, 16) };
21+
let mat_b = unsafe { MatB::load(b, COLUMN_MAJOR, 16) };
22+
let mat_c = unsafe { MatAcc::load(c, ROW_MAJOR, 16) };
2323

2424
let result = mul_add(mat_a, mat_b, mat_c);
2525

26-
unsafe { result.store(c.as_mut_ptr(), ROW_MAJOR, 16) };
26+
unsafe { result.store(c, ROW_MAJOR, 16) };
2727
}

0 commit comments

Comments
 (0)