Skip to content

Commit ee7301d

Browse files
committed
update correctness testing for winograd_conv2d_nhwc
1 parent 1dfcf4c commit ee7301d

1 file changed

Lines changed: 89 additions & 0 deletions

File tree

crates/yscv-kernels/src/ops/conv/gemm_conv.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,4 +1159,93 @@ mod tests {
11591159
assert_eq!(scalar_output, vectorized_output);
11601160
}
11611161
}
1162+
1163+
fn assert_winograd_matches_indirect_padded(with_bias: bool, activation: Activation) {
1164+
let batch = 2;
1165+
let in_h = 8;
1166+
let in_w = 9;
1167+
let c_in = 3;
1168+
let c_out = 5;
1169+
let pad_top = 1;
1170+
let pad_left = 1;
1171+
let pad_bottom = 1;
1172+
let pad_right = 1;
1173+
1174+
let input_data: Vec<f32> = (0..batch * in_h * in_w * c_in)
1175+
.map(|i| (i as f32 % 17.0) * 0.05 - 0.4)
1176+
.collect();
1177+
let kernel_data: Vec<f32> = (0..3 * 3 * c_in * c_out)
1178+
.map(|i| (i as f32 % 11.0) * 0.03 - 0.15)
1179+
.collect();
1180+
1181+
let input = Tensor::from_vec(vec![batch, in_h, in_w, c_in], input_data).unwrap();
1182+
let kernel = Tensor::from_vec(vec![3, 3, c_in, c_out], kernel_data).unwrap();
1183+
let bias = with_bias.then(|| {
1184+
let bias_data: Vec<f32> = (0..c_out).map(|i| i as f32 * 0.07 - 0.13).collect();
1185+
Tensor::from_vec(vec![c_out], bias_data).unwrap()
1186+
});
1187+
1188+
let winograd = winograd_conv2d_nhwc(
1189+
input.data(),
1190+
kernel.data(),
1191+
bias.as_ref().map(Tensor::data),
1192+
batch,
1193+
in_h,
1194+
in_w,
1195+
c_in,
1196+
c_out,
1197+
pad_top,
1198+
pad_left,
1199+
pad_bottom,
1200+
pad_right,
1201+
activation,
1202+
)
1203+
.unwrap();
1204+
let indirect = conv2d_nhwc_indirect_padded(
1205+
&input,
1206+
&kernel,
1207+
bias.as_ref(),
1208+
1,
1209+
1,
1210+
pad_top,
1211+
pad_left,
1212+
pad_bottom,
1213+
pad_right,
1214+
activation,
1215+
)
1216+
.unwrap();
1217+
1218+
assert_eq!(winograd.shape(), indirect.shape());
1219+
for (idx, (&actual, &expected)) in winograd.data().iter().zip(indirect.data()).enumerate() {
1220+
let diff = (actual - expected).abs();
1221+
assert!(
1222+
diff <= 1.0e-4,
1223+
"with_bias={with_bias} activation={activation:?} idx={idx} actual={actual} expected={expected} diff={diff}"
1224+
);
1225+
}
1226+
}
1227+
1228+
#[test]
1229+
fn winograd_conv2d_matches_indirect_padded() {
1230+
assert_winograd_matches_indirect_padded(false, Activation::None);
1231+
}
1232+
1233+
#[test]
1234+
fn winograd_conv2d_matches_indirect_padded_bias_only() {
1235+
assert_winograd_matches_indirect_padded(true, Activation::None);
1236+
}
1237+
1238+
#[test]
1239+
fn winograd_conv2d_matches_indirect_padded_activation_only() {
1240+
for activation in [Activation::Relu, Activation::Silu] {
1241+
assert_winograd_matches_indirect_padded(false, activation);
1242+
}
1243+
}
1244+
1245+
#[test]
1246+
fn winograd_conv2d_matches_indirect_padded_with_bias_and_activation() {
1247+
for activation in [Activation::Relu, Activation::Silu] {
1248+
assert_winograd_matches_indirect_padded(true, activation);
1249+
}
1250+
}
11621251
}

0 commit comments

Comments
 (0)