2121#include "linear_fp_per_out_channel_params.glslh"
2222#include "linear_fp_weight_tile.glslh"
2323
24+ #if defined(LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT) == defined(LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT)
25+ #define MAYBE_CAST_WVEC4(x) (x)
26+ #else
27+ #define MAYBE_CAST_WVEC4(x) LINEAR_FP_OUTPUT_TILE_VEC4_T(x)
28+ #endif
29+
2430void fp_accumulate_with_fp_weight(
2531 inout FPOutTile accum,
2632 FPInputTile in_tile,
@@ -29,23 +35,23 @@ void fp_accumulate_with_fp_weight(
2935 [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
3036 [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
3137 accum.data[m][n4] =
32- fma(VEC4_T (in_tile.data[m][k4][0]),
33- w_tile.data[mul_4(k4)][n4],
38+ fma(LINEAR_FP_OUTPUT_TILE_VEC4_T (in_tile.data[m][k4][0]),
39+ MAYBE_CAST_WVEC4( w_tile.data[mul_4(k4)][n4]) ,
3440 accum.data[m][n4]);
3541
3642 accum.data[m][n4] =
37- fma(VEC4_T (in_tile.data[m][k4][1]),
38- w_tile.data[mul_4(k4) + 1][n4],
43+ fma(LINEAR_FP_OUTPUT_TILE_VEC4_T (in_tile.data[m][k4][1]),
44+ MAYBE_CAST_WVEC4( w_tile.data[mul_4(k4) + 1][n4]) ,
3945 accum.data[m][n4]);
4046
4147 accum.data[m][n4] =
42- fma(VEC4_T (in_tile.data[m][k4][2]),
43- w_tile.data[mul_4(k4) + 2][n4],
48+ fma(LINEAR_FP_OUTPUT_TILE_VEC4_T (in_tile.data[m][k4][2]),
49+ MAYBE_CAST_WVEC4( w_tile.data[mul_4(k4) + 2][n4]) ,
4450 accum.data[m][n4]);
4551
4652 accum.data[m][n4] =
47- fma(VEC4_T (in_tile.data[m][k4][3]),
48- w_tile.data[mul_4(k4) + 3][n4],
53+ fma(LINEAR_FP_OUTPUT_TILE_VEC4_T (in_tile.data[m][k4][3]),
54+ MAYBE_CAST_WVEC4( w_tile.data[mul_4(k4) + 3][n4]) ,
4955 accum.data[m][n4]);
5056 }
5157 }
0 commit comments