@@ -555,6 +555,8 @@ fn winograd_conv2d_nhwc(
555555 pad_right : usize ,
556556 activation : Activation ,
557557) -> Result < Tensor , KernelError > {
558+ use rayon:: iter:: { IntoParallelIterator , ParallelIterator } ;
559+
558560 let padded_h = in_h + pad_top + pad_bottom;
559561 let padded_w = in_w + pad_left + pad_right;
560562 let out_h = padded_h - 2 ; // (padded_h - 3) / 1 + 1
@@ -573,6 +575,8 @@ fn winograd_conv2d_nhwc(
573575 let mut output = AlignedVec :: < f32 > :: uninitialized ( batch * out_h * out_w * c_out) ;
574576
575577 for b in 0 ..batch {
578+ use crate :: GemmEpilogue ;
579+
576580 let in_batch = & input[ b * in_h * in_w * c_in..( b + 1 ) * in_h * in_w * c_in] ;
577581
578582 // 2. Input transform: for each tile, for each channel, compute B^T * d * B
@@ -611,11 +615,40 @@ fn winograd_conv2d_nhwc(
611615 // V[alpha]: [n_tiles, c_in], U[alpha]: [c_in, c_out]
612616 // M[alpha]: [n_tiles, c_out]
613617 let mut m_buf = vec ! [ 0.0f32 ; 16 * n_tiles * c_out] ;
618+ let epilogue = GemmEpilogue {
619+ activation : Activation :: None ,
620+ bias : None ,
621+ residual : None ,
622+ } ;
623+ let config = ParallelMatmulConfig :: default ( ) ;
624+
625+ let packed_u: Option < Vec < _ > > =
626+ if should_parallelize_len ( m_buf. len ( ) , config. min_parallel_output_elements , None ) {
627+ Some (
628+ ( 0 ..16 )
629+ . into_par_iter ( )
630+ . map ( |a| {
631+ use crate :: pack_b_for_session;
632+
633+ let u_slice = & u[ a * c_in * c_out..( a + 1 ) * c_in * c_out] ;
634+ pack_b_for_session ( u_slice, c_in, c_out)
635+ } )
636+ . collect ( ) ,
637+ )
638+ } else {
639+ None
640+ } ;
614641 for a in 0 ..16 {
642+ use crate :: matmul_2d_slices_fused_maybe_packed;
643+
615644 let v_slice = & v[ a * n_tiles * c_in..( a + 1 ) * n_tiles * c_in] ;
616645 let u_slice = & u[ a * c_in * c_out..( a + 1 ) * c_in * c_out] ;
617646 let m_slice = & mut m_buf[ a * n_tiles * c_out..( a + 1 ) * n_tiles * c_out] ;
618- super :: super :: matmul:: blas_sgemm ( v_slice, u_slice, m_slice, n_tiles, c_in, c_out) ;
647+ let packed = packed_u. as_ref ( ) . map ( |packed_u| packed_u[ a] . as_ref ( ) ) ;
648+
649+ matmul_2d_slices_fused_maybe_packed (
650+ v_slice, n_tiles, c_in, u_slice, c_out, m_slice, packed, epilogue, config, None ,
651+ ) ;
619652 }
620653
621654 // 4. Output transform: A^T * M * A → 2×2 output per tile, with bias + activation
0 commit comments