Skip to content

Commit 9cd97c4

Browse files
committed
Parallelize conv2d_nhwc_indirect_padded() across output rows
1 parent 2e8b441 commit 9cd97c4

1 file changed

Lines changed: 48 additions & 39 deletions

File tree

crates/yscv-kernels/src/ops/conv/gemm_conv.rs

Lines changed: 48 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
//! GEMM-based general conv paths: im2col + BLAS GEMM, indirect (padding-free)
22
//! convolution, Winograd F(2x2,3x3), and the zero-padded NHWC conv.
33
4+
use crate::core::scope_ctx::par_chunks_mut_dispatch;
5+
46
use super::*;
57

68
// ---------------------------------------------------------------------------
@@ -51,54 +53,61 @@ pub fn conv2d_nhwc_indirect_padded(
5153
// Zero buffer for padded positions
5254
let zero_pixel = vec![0.0f32; c_in];
5355

54-
for b in 0..batch {
55-
let batch_in_base = b * in_h * in_w * c_in;
56-
for oy in 0..out_h {
57-
let out_row_start = (b * out_h + oy) * out_row_len;
58-
let out_row = &mut output[out_row_start..out_row_start + out_row_len];
56+
let config = ParallelMatmulConfig::default();
57+
let process_row = |row_idx: usize, out_row: &mut [f32]| {
58+
let b = row_idx / out_h;
59+
let oy = row_idx % out_h;
5960

60-
for ox in 0..out_w {
61-
let out_cell = &mut out_row[ox * c_out..(ox + 1) * c_out];
61+
let batch_in_base = b * in_h * in_w * c_in;
6262

63-
// Init with bias
64-
if let Some(bv) = bias_data {
65-
out_cell.copy_from_slice(&bv[..c_out]);
66-
} else {
67-
out_cell.fill(0.0);
68-
}
63+
for ox in 0..out_w {
64+
let out_cell = &mut out_row[ox * c_out..(ox + 1) * c_out];
6965

70-
// Accumulate kernel positions with inline padding check
71-
for ky in 0..kh {
72-
let iy = oy * stride_h + ky;
73-
let in_y = iy as isize - pad_top as isize;
66+
// Init with bias
67+
if let Some(bv) = bias_data {
68+
out_cell.copy_from_slice(&bv[..c_out]);
69+
} else {
70+
out_cell.fill(0.0);
71+
}
7472

75-
for kx in 0..kw {
76-
let ix = ox * stride_w + kx;
77-
let in_x = ix as isize - pad_left as isize;
78-
79-
let input_pixel = if in_y >= 0
80-
&& (in_y as usize) < in_h
81-
&& in_x >= 0
82-
&& (in_x as usize) < in_w
83-
{
84-
let offset =
85-
batch_in_base + (in_y as usize * in_w + in_x as usize) * c_in;
86-
&in_data[offset..offset + c_in]
87-
} else {
88-
&zero_pixel
89-
};
73+
// Accumulate kernel positions with inline padding check
74+
for ky in 0..kh {
75+
let iy = oy * stride_h + ky;
76+
let in_y = iy as isize - pad_top as isize;
9077

91-
let k_base = (ky * kw + kx) * c_in * c_out;
92-
for ic in 0..c_in {
93-
let iv = input_pixel[ic];
94-
let kb = k_base + ic * c_out;
95-
conv_fma_row(out_cell, &ker_data[kb..kb + c_out], iv);
96-
}
78+
for kx in 0..kw {
79+
let ix = ox * stride_w + kx;
80+
let in_x = ix as isize - pad_left as isize;
81+
82+
let input_pixel = if in_y >= 0
83+
&& (in_y as usize) < in_h
84+
&& in_x >= 0
85+
&& (in_x as usize) < in_w
86+
{
87+
let offset = batch_in_base + (in_y as usize * in_w + in_x as usize) * c_in;
88+
&in_data[offset..offset + c_in]
89+
} else {
90+
&zero_pixel
91+
};
92+
93+
let k_base = (ky * kw + kx) * c_in * c_out;
94+
for ic in 0..c_in {
95+
let iv = input_pixel[ic];
96+
let kb = k_base + ic * c_out;
97+
conv_fma_row(out_cell, &ker_data[kb..kb + c_out], iv);
9798
}
9899
}
99100
}
101+
}
100102

101-
apply_conv_activation_inplace(out_row, activation);
103+
apply_conv_activation_inplace(out_row, activation);
104+
};
105+
106+
if should_parallelize_len(output_len, config.min_parallel_output_elements, None) {
107+
par_chunks_mut_dispatch(&mut output, out_row_len, process_row);
108+
} else {
109+
for (row_idx, out_row) in output.chunks_mut(out_row_len).enumerate() {
110+
process_row(row_idx, out_row);
102111
}
103112
}
104113

0 commit comments

Comments
 (0)