@@ -43,6 +43,7 @@ static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, Ce
4343 const CeedScalar * restrict t , CeedTransposeMode t_mode , const CeedInt add ,
4444 const CeedScalar * restrict u , CeedScalar * restrict v , const CeedInt JJ , const CeedInt CC ) {
4545 CeedInt t_stride_0 = B , t_stride_1 = 1 ;
46+
4647 if (t_mode == CEED_TRANSPOSE ) {
4748 t_stride_0 = 1 ;
4849 t_stride_1 = J ;
@@ -56,7 +57,6 @@ static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, Ce
5657 for (CeedInt jj = 0 ; jj < JJ ; jj ++ ) {
5758 for (CeedInt cc = 0 ; cc < CC / 4 ; cc ++ ) vv [jj ][cc ] = loadu (& v [(a * J + j + jj ) * C + c + cc * 4 ]);
5859 }
59-
6060 for (CeedInt b = 0 ; b < B ; b ++ ) {
6161 for (CeedInt jj = 0 ; jj < JJ ; jj ++ ) { // unroll
6262 rtype tqv = set1 (t [(j + jj ) * t_stride_0 + b * t_stride_1 ]);
@@ -71,17 +71,19 @@ static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, Ce
7171 }
7272 }
7373 // Remainder of rows
74- CeedInt j = (J / JJ ) * JJ ;
74+ const CeedInt j = (J / JJ ) * JJ ;
75+
7576 if (j < J ) {
7677 for (CeedInt c = 0 ; c < (C / CC ) * CC ; c += CC ) {
7778 rtype vv [JJ ][CC / 4 ]; // Output tile to be held in registers
79+
7880 for (CeedInt jj = 0 ; jj < J - j ; jj ++ ) {
7981 for (CeedInt cc = 0 ; cc < CC / 4 ; cc ++ ) vv [jj ][cc ] = loadu (& v [(a * J + j + jj ) * C + c + cc * 4 ]);
8082 }
81-
8283 for (CeedInt b = 0 ; b < B ; b ++ ) {
8384 for (CeedInt jj = 0 ; jj < J - j ; jj ++ ) { // doesn't unroll
8485 rtype tqv = set1 (t [(j + jj ) * t_stride_0 + b * t_stride_1 ]);
86+
8587 for (CeedInt cc = 0 ; cc < CC / 4 ; cc ++ ) { // unroll
8688 fmadd (vv [jj ][cc ], tqv , loadu (& u [(a * B + b ) * C + c + cc * 4 ]));
8789 }
@@ -103,22 +105,25 @@ static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
103105 const CeedScalar * restrict t , CeedTransposeMode t_mode , const CeedInt add ,
104106 const CeedScalar * restrict u , CeedScalar * restrict v , const CeedInt JJ , const CeedInt CC ) {
105107 CeedInt t_stride_0 = B , t_stride_1 = 1 ;
108+
106109 if (t_mode == CEED_TRANSPOSE ) {
107110 t_stride_0 = 1 ;
108111 t_stride_1 = J ;
109112 }
110113
111- CeedInt J_break = J % JJ ? (J / JJ ) * JJ : (J / JJ - 1 ) * JJ ;
114+ const CeedInt J_break = J % JJ ? (J / JJ ) * JJ : (J / JJ - 1 ) * JJ ;
115+
112116 for (CeedInt a = 0 ; a < A ; a ++ ) {
113117 // Blocks of 4 columns
114118 for (CeedInt c = (C / CC ) * CC ; c < C ; c += 4 ) {
115119 // Blocks of 4 rows
116120 for (CeedInt j = 0 ; j < J_break ; j += JJ ) {
117121 rtype vv [JJ ]; // Output tile to be held in registers
118- for (CeedInt jj = 0 ; jj < JJ ; jj ++ ) vv [jj ] = loadu (& v [(a * J + j + jj ) * C + c ]);
119122
123+ for (CeedInt jj = 0 ; jj < JJ ; jj ++ ) vv [jj ] = loadu (& v [(a * J + j + jj ) * C + c ]);
120124 for (CeedInt b = 0 ; b < B ; b ++ ) {
121125 rtype tqu ;
126+
122127 if (C - c == 1 ) tqu = set (0.0 , 0.0 , 0.0 , u [(a * B + b ) * C + c + 0 ]);
123128 else if (C - c == 2 ) tqu = set (0.0 , 0.0 , u [(a * B + b ) * C + c + 1 ], u [(a * B + b ) * C + c + 0 ]);
124129 else if (C - c == 3 ) tqu = set (0.0 , u [(a * B + b ) * C + c + 2 ], u [(a * B + b ) * C + c + 1 ], u [(a * B + b ) * C + c + 0 ]);
@@ -133,7 +138,8 @@ static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
133138 // Remainder of rows, all columns
134139 for (CeedInt j = J_break ; j < J ; j ++ ) {
135140 for (CeedInt b = 0 ; b < B ; b ++ ) {
136- CeedScalar tq = t [j * t_stride_0 + b * t_stride_1 ];
141+ const CeedScalar tq = t [j * t_stride_0 + b * t_stride_1 ];
142+
137143 for (CeedInt c = (C / CC ) * CC ; c < C ; c ++ ) v [(a * J + j ) * C + c ] += tq * u [(a * B + b ) * C + c ];
138144 }
139145 }
@@ -148,6 +154,7 @@ static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, Cee
148154 CeedTransposeMode t_mode , const CeedInt add , const CeedScalar * restrict u , CeedScalar * restrict v ,
149155 const CeedInt AA , const CeedInt JJ ) {
150156 CeedInt t_stride_0 = B , t_stride_1 = 1 ;
157+
151158 if (t_mode == CEED_TRANSPOSE ) {
152159 t_stride_0 = 1 ;
153160 t_stride_1 = J ;
@@ -157,14 +164,15 @@ static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, Cee
157164 for (CeedInt a = 0 ; a < (A / AA ) * AA ; a += AA ) {
158165 for (CeedInt j = 0 ; j < (J / JJ ) * JJ ; j += JJ ) {
159166 rtype vv [AA ][JJ / 4 ]; // Output tile to be held in registers
167+
160168 for (CeedInt aa = 0 ; aa < AA ; aa ++ ) {
161169 for (CeedInt jj = 0 ; jj < JJ / 4 ; jj ++ ) vv [aa ][jj ] = loadu (& v [(a + aa ) * J + j + jj * 4 ]);
162170 }
163-
164171 for (CeedInt b = 0 ; b < B ; b ++ ) {
165172 for (CeedInt jj = 0 ; jj < JJ / 4 ; jj ++ ) { // unroll
166173 rtype tqv = set (t [(j + jj * 4 + 3 ) * t_stride_0 + b * t_stride_1 ], t [(j + jj * 4 + 2 ) * t_stride_0 + b * t_stride_1 ],
167174 t [(j + jj * 4 + 1 ) * t_stride_0 + b * t_stride_1 ], t [(j + jj * 4 + 0 ) * t_stride_0 + b * t_stride_1 ]);
175+
168176 for (CeedInt aa = 0 ; aa < AA ; aa ++ ) { // unroll
169177 fmadd (vv [aa ][jj ], tqv , set1 (u [(a + aa ) * B + b ]));
170178 }
@@ -176,17 +184,19 @@ static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, Cee
176184 }
177185 }
178186 // Remainder of rows
179- CeedInt a = (A / AA ) * AA ;
187+ const CeedInt a = (A / AA ) * AA ;
188+
180189 for (CeedInt j = 0 ; j < (J / JJ ) * JJ ; j += JJ ) {
181190 rtype vv [AA ][JJ / 4 ]; // Output tile to be held in registers
191+
182192 for (CeedInt aa = 0 ; aa < A - a ; aa ++ ) {
183193 for (CeedInt jj = 0 ; jj < JJ / 4 ; jj ++ ) vv [aa ][jj ] = loadu (& v [(a + aa ) * J + j + jj * 4 ]);
184194 }
185-
186195 for (CeedInt b = 0 ; b < B ; b ++ ) {
187196 for (CeedInt jj = 0 ; jj < JJ / 4 ; jj ++ ) { // unroll
188197 rtype tqv = set (t [(j + jj * 4 + 3 ) * t_stride_0 + b * t_stride_1 ], t [(j + jj * 4 + 2 ) * t_stride_0 + b * t_stride_1 ],
189198 t [(j + jj * 4 + 1 ) * t_stride_0 + b * t_stride_1 ], t [(j + jj * 4 + 0 ) * t_stride_0 + b * t_stride_1 ]);
199+
190200 for (CeedInt aa = 0 ; aa < A - a ; aa ++ ) { // unroll
191201 fmadd (vv [aa ][jj ], tqv , set1 (u [(a + aa ) * B + b ]));
192202 }
@@ -197,16 +207,18 @@ static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, Cee
197207 }
198208 }
199209 // Column remainder
200- CeedInt A_break = A % AA ? (A / AA ) * AA : (A / AA - 1 ) * AA ;
210+ const CeedInt A_break = A % AA ? (A / AA ) * AA : (A / AA - 1 ) * AA ;
211+
201212 // Blocks of 4 columns
202213 for (CeedInt j = (J / JJ ) * JJ ; j < J ; j += 4 ) {
203214 // Blocks of 4 rows
204215 for (CeedInt a = 0 ; a < A_break ; a += AA ) {
205216 rtype vv [AA ]; // Output tile to be held in registers
206- for (CeedInt aa = 0 ; aa < AA ; aa ++ ) vv [aa ] = loadu (& v [(a + aa ) * J + j ]);
207217
218+ for (CeedInt aa = 0 ; aa < AA ; aa ++ ) vv [aa ] = loadu (& v [(a + aa ) * J + j ]);
208219 for (CeedInt b = 0 ; b < B ; b ++ ) {
209220 rtype tqv ;
221+
210222 if (J - j == 1 ) {
211223 tqv = set (0.0 , 0.0 , 0.0 , t [(j + 0 ) * t_stride_0 + b * t_stride_1 ]);
212224 } else if (J - j == 2 ) {
@@ -228,7 +240,8 @@ static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, Cee
228240 // Remainder of rows, all columns
229241 for (CeedInt b = 0 ; b < B ; b ++ ) {
230242 for (CeedInt j = (J / JJ ) * JJ ; j < J ; j ++ ) {
231- CeedScalar tq = t [j * t_stride_0 + b * t_stride_1 ];
243+ const CeedScalar tq = t [j * t_stride_0 + b * t_stride_1 ];
244+
232245 for (CeedInt a = A_break ; a < A ; a ++ ) v [a * J + j ] += tq * u [a * B + b ];
233246 }
234247 }
@@ -271,7 +284,6 @@ static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, C
271284 // Remainder of columns
272285 if (C % blk_size ) CeedTensorContract_Avx_Remainder_8_8 (contract , A , B , C , J , t , t_mode , true, u , v );
273286 }
274-
275287 return CEED_ERROR_SUCCESS ;
276288}
277289
@@ -280,10 +292,9 @@ static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, C
280292//------------------------------------------------------------------------------
281293int CeedTensorContractCreate_Avx (CeedBasis basis , CeedTensorContract contract ) {
282294 Ceed ceed ;
283- CeedCallBackend (CeedTensorContractGetCeed (contract , & ceed ));
284295
296+ CeedCallBackend (CeedTensorContractGetCeed (contract , & ceed ));
285297 CeedCallBackend (CeedSetBackendFunction (ceed , "TensorContract" , contract , "Apply" , CeedTensorContractApply_Avx ));
286-
287298 return CEED_ERROR_SUCCESS ;
288299}
289300
0 commit comments