Skip to content

Commit eb2aa02

Browse files
committed
Use matmul_2d_slices_fused_maybe_packed
it gives us: * internal parallelization * advantage of `packed_b` usage (when applicable)
1 parent 600c262 commit eb2aa02

1 file changed

Lines changed: 34 additions & 1 deletion

File tree

crates/yscv-kernels/src/ops/conv/gemm_conv.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,8 @@ fn winograd_conv2d_nhwc(
555555
pad_right: usize,
556556
activation: Activation,
557557
) -> Result<Tensor, KernelError> {
558+
use rayon::iter::{IntoParallelIterator, ParallelIterator};
559+
558560
let padded_h = in_h + pad_top + pad_bottom;
559561
let padded_w = in_w + pad_left + pad_right;
560562
let out_h = padded_h - 2; // (padded_h - 3) / 1 + 1
@@ -573,6 +575,8 @@ fn winograd_conv2d_nhwc(
573575
let mut output = AlignedVec::<f32>::uninitialized(batch * out_h * out_w * c_out);
574576

575577
for b in 0..batch {
578+
use crate::GemmEpilogue;
579+
576580
let in_batch = &input[b * in_h * in_w * c_in..(b + 1) * in_h * in_w * c_in];
577581

578582
// 2. Input transform: for each tile, for each channel, compute B^T * d * B
@@ -611,11 +615,40 @@ fn winograd_conv2d_nhwc(
611615
// V[alpha]: [n_tiles, c_in], U[alpha]: [c_in, c_out]
612616
// M[alpha]: [n_tiles, c_out]
613617
let mut m_buf = vec![0.0f32; 16 * n_tiles * c_out];
618+
let epilogue = GemmEpilogue {
619+
activation: Activation::None,
620+
bias: None,
621+
residual: None,
622+
};
623+
let config = ParallelMatmulConfig::default();
624+
625+
let packed_u: Option<Vec<_>> =
626+
if should_parallelize_len(m_buf.len(), config.min_parallel_output_elements, None) {
627+
Some(
628+
(0..16)
629+
.into_par_iter()
630+
.map(|a| {
631+
use crate::pack_b_for_session;
632+
633+
let u_slice = &u[a * c_in * c_out..(a + 1) * c_in * c_out];
634+
pack_b_for_session(u_slice, c_in, c_out)
635+
})
636+
.collect(),
637+
)
638+
} else {
639+
None
640+
};
614641
for a in 0..16 {
642+
use crate::matmul_2d_slices_fused_maybe_packed;
643+
615644
let v_slice = &v[a * n_tiles * c_in..(a + 1) * n_tiles * c_in];
616645
let u_slice = &u[a * c_in * c_out..(a + 1) * c_in * c_out];
617646
let m_slice = &mut m_buf[a * n_tiles * c_out..(a + 1) * n_tiles * c_out];
618-
super::super::matmul::blas_sgemm(v_slice, u_slice, m_slice, n_tiles, c_in, c_out);
647+
let packed = packed_u.as_ref().map(|packed_u| packed_u[a].as_ref());
648+
649+
matmul_2d_slices_fused_maybe_packed(
650+
v_slice, n_tiles, c_in, u_slice, c_out, m_slice, packed, epilogue, config, None,
651+
);
619652
}
620653

621654
// 4. Output transform: A^T * M * A → 2×2 output per tile, with bias + activation

0 commit comments

Comments
 (0)