55//! -C target-feature=+CooperativeMatrixKHR,+ext:SPV_KHR_cooperative_matrix
66//! ```
77//!
8+ //! See the [SPV_KHR_cooperative_matrix specification] for full details.
9+ //!
10+ //! [SPV_KHR_cooperative_matrix specification]: https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html
811#[ cfg( target_arch = "spirv" ) ]
912use core:: arch:: asm;
1013use core:: marker:: PhantomData ;
1114use core:: mem:: MaybeUninit ;
1215
/// Matrix role in a cooperative multiply-accumulate operation (`D = A × B + C`).
///
/// The enum is `#[repr(u32)]` so that a variant can be cast to `u32` and used
/// as the `USE` const generic parameter of `CooperativeMatrix`, e.g.
/// `{ MatrixUse::MatrixA as u32 }`. In `const` contexts use the `as u32` cast;
/// elsewhere `u32::from(role)` performs the same lossless conversion.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum MatrixUse {
    /// Input operand A (left-hand side of the product).
    MatrixA = 0,
    /// Input operand B (right-hand side of the product).
    MatrixB = 1,
    /// Accumulator operand C and the result D.
    MatrixAccumulator = 2,
}

impl From<MatrixUse> for u32 {
    /// Returns the numeric discriminant of the matrix role.
    #[inline]
    fn from(role: MatrixUse) -> Self {
        // Lossless: the enum is `#[repr(u32)]`, so the cast cannot truncate.
        role as u32
    }
}
27+
/// Memory layout for cooperative matrix load/store operations.
///
/// The enum is `#[repr(u32)]`; the numeric value of the chosen variant is what
/// `CooperativeMatrix::load` and `CooperativeMatrix::store` hand to the
/// underlying load/store instruction as its layout operand. Use
/// `u32::from(layout)` (or `layout as u32` in `const` contexts) to obtain it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum MatrixLayout {
    /// Rows are stored contiguously.
    RowMajor = 0,
    /// Columns are stored contiguously.
    ColumnMajor = 1,
}

impl From<MatrixLayout> for u32 {
    /// Returns the numeric discriminant of the memory layout.
    #[inline]
    fn from(layout: MatrixLayout) -> Self {
        // Lossless: the enum is `#[repr(u32)]`, so the cast cannot truncate.
        layout as u32
    }
}
37+
1338/// Matrix role: input operand A in D = A × B + C.
14- pub const MATRIX_A : u32 = 0 ;
39+ pub const MATRIX_A : u32 = MatrixUse :: MatrixA as u32 ;
1540/// Matrix role: input operand B in D = A × B + C.
16- pub const MATRIX_B : u32 = 1 ;
41+ pub const MATRIX_B : u32 = MatrixUse :: MatrixB as u32 ;
1742/// Matrix role: accumulator / result in D = A × B + C.
18- pub const MATRIX_ACCUMULATOR : u32 = 2 ;
43+ pub const MATRIX_ACCUMULATOR : u32 = MatrixUse :: MatrixAccumulator as u32 ;
1944
2045/// Memory layout: rows are stored contiguously.
21- pub const ROW_MAJOR : u32 = 0 ;
46+ pub const ROW_MAJOR : MatrixLayout = MatrixLayout :: RowMajor ;
2247/// Memory layout: columns are stored contiguously.
23- pub const COLUMN_MAJOR : u32 = 1 ;
48+ pub const COLUMN_MAJOR : MatrixLayout = MatrixLayout :: ColumnMajor ;
2449
2550/// A cooperative matrix distributed across the subgroup.
2651///
@@ -29,7 +54,7 @@ pub const COLUMN_MAJOR: u32 = 1;
2954///
3055/// # Type parameters
3156/// - `T`: element type (`f32`, `f64`, `i32`, `u32`, `i8`, `u8`, etc.)
32- /// - `USE`: matrix role — one of [`MATRIX_A `], [`MATRIX_B `], [`MATRIX_ACCUMULATOR`]
57+ /// - `USE`: matrix role — one of [`MatrixUse::MatrixA `], [`MatrixUse::MatrixB `], [`MatrixUse::MatrixAccumulator`] cast to `u32`
3358/// - `ROWS`: number of rows
3459/// - `COLS`: number of columns
3560///
@@ -46,22 +71,28 @@ pub struct CooperativeMatrix<T, const USE: u32, const ROWS: u32, const COLS: u32
4671}
4772
4873impl < T , const USE : u32 , const ROWS : u32 , const COLS : u32 > CooperativeMatrix < T , USE , ROWS , COLS > {
49- /// Load a cooperative matrix tile from `ptr` using `layout` and `stride` .
74+ /// Load a cooperative matrix through a pointer .
5075 ///
51- /// - `ptr`: pointer to the first element of the tile in memory
52- /// - `layout`: [`ROW_MAJOR`] or [`COLUMN_MAJOR`]
53- /// - `stride`: distance (in elements) between the start of consecutive
54- /// rows (row-major) or columns (column-major) in the full matrix
76+ /// `slice` must point into an array. `layout` specifies whether the matrix
77+ /// is stored in row-major ([`MatrixLayout::RowMajor`]) or column-major
78+ /// ([`MatrixLayout::ColumnMajor`]) order. `stride` is the number of elements
79+ /// between the start of consecutive rows (row-major) or columns (column-major).
80+ ///
81+ /// The scope is always `Subgroup`.
5582 ///
5683 /// # Safety
57- /// `ptr` must be valid for `ROWS * stride` (row-major) or `COLS * stride`
58- /// (column-major) element reads within the subgroup's access pattern.
84+ /// - `slice` must point into an array and be valid for all element accesses
85+ /// implied by the matrix dimensions, layout, and stride.
86+ /// - All operands must be dynamically uniform within every instance of the
87+ /// subgroup scope.
5988 #[ spirv_std_macros:: gpu_only]
6089 #[ doc( alias = "OpCooperativeMatrixLoadKHR" ) ]
6190 #[ inline]
62- pub unsafe fn load ( ptr : * const T , layout : u32 , stride : u32 ) -> Self {
91+ pub unsafe fn load ( slice : & [ T ] , layout : MatrixLayout , stride : u32 ) -> Self {
6392 unsafe {
6493 let mut result = MaybeUninit :: < Self > :: uninit ( ) ;
94+ let layout_u32 = layout as u32 ;
95+ let ptr = slice. as_ptr ( ) ;
6596 asm ! (
6697 "%u32 = OpTypeInt 32 0" ,
6798 "%layout = OpLoad %u32 {layout}" ,
@@ -70,24 +101,35 @@ impl<T, const USE: u32, const ROWS: u32, const COLS: u32> CooperativeMatrix<T, U
70101 "%result = OpCooperativeMatrixLoadKHR typeof* {out} {ptr} %layout %stride" ,
71102 "OpStore {out} %result" ,
72103 ptr = in( reg) ptr,
73- layout = in( reg) & layout ,
104+ layout = in( reg) & layout_u32 ,
74105 stride = in( reg) & stride,
75106 out = in( reg) result. as_mut_ptr( ) ,
76107 ) ;
77108 result. assume_init ( )
78109 }
79110 }
80111
81- /// Store this cooperative matrix tile to `ptr` using `layout` and `stride`.
112+ /// Store a cooperative matrix through a pointer.
113+ ///
114+ /// `slice` must point into an array. `layout` specifies whether the matrix
115+ /// is stored in row-major ([`MatrixLayout::RowMajor`]) or column-major
116+ /// ([`MatrixLayout::ColumnMajor`]) order. `stride` is the number of elements
117+ /// between the start of consecutive rows (row-major) or columns (column-major).
118+ ///
119+ /// The scope is always `Subgroup`.
82120 ///
83121 /// # Safety
84- /// `ptr` must be valid for `ROWS * stride` (row-major) or `COLS * stride`
85- /// (column-major) element writes within the subgroup's access pattern.
122+ /// - `slice` must point into an array and be valid for all element accesses
123+ /// implied by the matrix dimensions, layout, and stride.
124+ /// - All operands must be dynamically uniform within every instance of the
125+ /// subgroup scope.
86126 #[ spirv_std_macros:: gpu_only]
87127 #[ doc( alias = "OpCooperativeMatrixStoreKHR" ) ]
88128 #[ inline]
89- pub unsafe fn store ( self , ptr : * mut T , layout : u32 , stride : u32 ) {
129+ pub unsafe fn store ( self , slice : & mut [ T ] , layout : MatrixLayout , stride : u32 ) {
90130 unsafe {
131+ let layout_u32 = layout as u32 ;
132+ let ptr = slice. as_mut_ptr ( ) ;
91133 asm ! (
92134 "%u32 = OpTypeInt 32 0" ,
93135 "%layout = OpLoad %u32 {layout}" ,
@@ -96,7 +138,7 @@ impl<T, const USE: u32, const ROWS: u32, const COLS: u32> CooperativeMatrix<T, U
96138 "OpCooperativeMatrixStoreKHR {ptr} %matrix %layout %stride" ,
97139 ptr = in( reg) ptr,
98140 matrix = in( reg) & self ,
99- layout = in( reg) & layout ,
141+ layout = in( reg) & layout_u32 ,
100142 stride = in( reg) & stride,
101143 ) ;
102144 }
@@ -125,26 +167,32 @@ impl<T, const USE: u32, const ROWS: u32, const COLS: u32> CooperativeMatrix<T, U
125167 }
126168}
127169
128- /// Fused multiply-accumulate: `D = A × B + C`.
170+ /// Linear-algebraic matrix multiply of `A` by `B` and then component-wise add `C`.
171+ ///
172+ /// The order of operations is implementation-dependent. All matrices must have the
173+ /// same scope, which is always subgroup here.
129174///
130- /// - `A`: `M × K` matrix with role [`MATRIX_A `]
131- /// - `B`: `K × N` matrix with role [`MATRIX_B `]
132- /// - `C`: `M × N` matrix with role [`MATRIX_ACCUMULATOR `]
133- /// - returns: `M × N` accumulator equal to `A × B + C`
175+ /// - `A`: `M × K` matrix with use [`MatrixUse::MatrixA `]
176+ /// - `B`: `K × N` matrix with use [`MatrixUse::MatrixB `]
177+ /// - `C`: `M × N` matrix with use [`MatrixUse::MatrixAccumulator `]
178+ /// - returns `D` : `M × N` accumulator equal to `A × B + C`
134179///
180+ /// All operands must be dynamically uniform within every instance of the subgroup scope.
135181///
136182/// # Capability
137183/// Requires `CooperativeMatrixKHR` + `SPV_KHR_cooperative_matrix`.
138184#[ spirv_std_macros:: gpu_only]
139185#[ doc( alias = "OpCooperativeMatrixMulAddKHR" ) ]
140186#[ inline]
141187pub fn mul_add < TA , TB , TC , const M : u32 , const N : u32 , const K : u32 > (
142- a : CooperativeMatrix < TA , MATRIX_A , M , K > ,
143- b : CooperativeMatrix < TB , MATRIX_B , K , N > ,
144- c : CooperativeMatrix < TC , MATRIX_ACCUMULATOR , M , N > ,
145- ) -> CooperativeMatrix < TC , MATRIX_ACCUMULATOR , M , N > {
188+ a : CooperativeMatrix < TA , { MatrixUse :: MatrixA as u32 } , M , K > ,
189+ b : CooperativeMatrix < TB , { MatrixUse :: MatrixB as u32 } , K , N > ,
190+ c : CooperativeMatrix < TC , { MatrixUse :: MatrixAccumulator as u32 } , M , N > ,
191+ ) -> CooperativeMatrix < TC , { MatrixUse :: MatrixAccumulator as u32 } , M , N > {
146192 unsafe {
147- let mut result = MaybeUninit :: < CooperativeMatrix < TC , MATRIX_ACCUMULATOR , M , N > > :: uninit ( ) ;
193+ let mut result = MaybeUninit :: <
194+ CooperativeMatrix < TC , { MatrixUse :: MatrixAccumulator as u32 } , M , N > ,
195+ > :: uninit ( ) ;
148196 asm ! (
149197 "%a = OpLoad _ {a}" ,
150198 "%b = OpLoad _ {b}" ,
0 commit comments