|
1 | 1 | //! GEMM-based general conv paths: im2col + BLAS GEMM, indirect (padding-free) |
2 | 2 | //! convolution, Winograd F(2x2,3x3), and the zero-padded NHWC conv. |
3 | 3 |
|
| 4 | +use crate::core::scope_ctx::par_chunks_mut_dispatch; |
| 5 | + |
4 | 6 | use super::*; |
5 | 7 |
|
6 | 8 | // --------------------------------------------------------------------------- |
@@ -51,54 +53,61 @@ pub fn conv2d_nhwc_indirect_padded( |
51 | 53 | // Zero buffer for padded positions |
52 | 54 | let zero_pixel = vec![0.0f32; c_in]; |
53 | 55 |
|
54 | | - for b in 0..batch { |
55 | | - let batch_in_base = b * in_h * in_w * c_in; |
56 | | - for oy in 0..out_h { |
57 | | - let out_row_start = (b * out_h + oy) * out_row_len; |
58 | | - let out_row = &mut output[out_row_start..out_row_start + out_row_len]; |
| 56 | + let config = ParallelMatmulConfig::default(); |
| 57 | + let process_row = |row_idx: usize, out_row: &mut [f32]| { |
| 58 | + let b = row_idx / out_h; |
| 59 | + let oy = row_idx % out_h; |
59 | 60 |
|
60 | | - for ox in 0..out_w { |
61 | | - let out_cell = &mut out_row[ox * c_out..(ox + 1) * c_out]; |
| 61 | + let batch_in_base = b * in_h * in_w * c_in; |
62 | 62 |
|
63 | | - // Init with bias |
64 | | - if let Some(bv) = bias_data { |
65 | | - out_cell.copy_from_slice(&bv[..c_out]); |
66 | | - } else { |
67 | | - out_cell.fill(0.0); |
68 | | - } |
| 63 | + for ox in 0..out_w { |
| 64 | + let out_cell = &mut out_row[ox * c_out..(ox + 1) * c_out]; |
69 | 65 |
|
70 | | - // Accumulate kernel positions with inline padding check |
71 | | - for ky in 0..kh { |
72 | | - let iy = oy * stride_h + ky; |
73 | | - let in_y = iy as isize - pad_top as isize; |
| 66 | + // Init with bias |
| 67 | + if let Some(bv) = bias_data { |
| 68 | + out_cell.copy_from_slice(&bv[..c_out]); |
| 69 | + } else { |
| 70 | + out_cell.fill(0.0); |
| 71 | + } |
74 | 72 |
|
75 | | - for kx in 0..kw { |
76 | | - let ix = ox * stride_w + kx; |
77 | | - let in_x = ix as isize - pad_left as isize; |
78 | | - |
79 | | - let input_pixel = if in_y >= 0 |
80 | | - && (in_y as usize) < in_h |
81 | | - && in_x >= 0 |
82 | | - && (in_x as usize) < in_w |
83 | | - { |
84 | | - let offset = |
85 | | - batch_in_base + (in_y as usize * in_w + in_x as usize) * c_in; |
86 | | - &in_data[offset..offset + c_in] |
87 | | - } else { |
88 | | - &zero_pixel |
89 | | - }; |
| 73 | + // Accumulate kernel positions with inline padding check |
| 74 | + for ky in 0..kh { |
| 75 | + let iy = oy * stride_h + ky; |
| 76 | + let in_y = iy as isize - pad_top as isize; |
90 | 77 |
|
91 | | - let k_base = (ky * kw + kx) * c_in * c_out; |
92 | | - for ic in 0..c_in { |
93 | | - let iv = input_pixel[ic]; |
94 | | - let kb = k_base + ic * c_out; |
95 | | - conv_fma_row(out_cell, &ker_data[kb..kb + c_out], iv); |
96 | | - } |
| 78 | + for kx in 0..kw { |
| 79 | + let ix = ox * stride_w + kx; |
| 80 | + let in_x = ix as isize - pad_left as isize; |
| 81 | + |
| 82 | + let input_pixel = if in_y >= 0 |
| 83 | + && (in_y as usize) < in_h |
| 84 | + && in_x >= 0 |
| 85 | + && (in_x as usize) < in_w |
| 86 | + { |
| 87 | + let offset = batch_in_base + (in_y as usize * in_w + in_x as usize) * c_in; |
| 88 | + &in_data[offset..offset + c_in] |
| 89 | + } else { |
| 90 | + &zero_pixel |
| 91 | + }; |
| 92 | + |
| 93 | + let k_base = (ky * kw + kx) * c_in * c_out; |
| 94 | + for ic in 0..c_in { |
| 95 | + let iv = input_pixel[ic]; |
| 96 | + let kb = k_base + ic * c_out; |
| 97 | + conv_fma_row(out_cell, &ker_data[kb..kb + c_out], iv); |
97 | 98 | } |
98 | 99 | } |
99 | 100 | } |
| 101 | + } |
100 | 102 |
|
101 | | - apply_conv_activation_inplace(out_row, activation); |
| 103 | + apply_conv_activation_inplace(out_row, activation); |
| 104 | + }; |
| 105 | + |
| 106 | + if should_parallelize_len(output_len, config.min_parallel_output_elements, None) { |
| 107 | + par_chunks_mut_dispatch(&mut output, out_row_len, process_row); |
| 108 | + } else { |
| 109 | + for (row_idx, out_row) in output.chunks_mut(out_row_len).enumerate() { |
| 110 | + process_row(row_idx, out_row); |
102 | 111 | } |
103 | 112 | } |
104 | 113 |
|
|
0 commit comments