@@ -1159,4 +1159,93 @@ mod tests {
11591159 assert_eq ! ( scalar_output, vectorized_output) ;
11601160 }
11611161 }
1162+
1163+ fn assert_winograd_matches_indirect_padded ( with_bias : bool , activation : Activation ) {
1164+ let batch = 2 ;
1165+ let in_h = 8 ;
1166+ let in_w = 9 ;
1167+ let c_in = 3 ;
1168+ let c_out = 5 ;
1169+ let pad_top = 1 ;
1170+ let pad_left = 1 ;
1171+ let pad_bottom = 1 ;
1172+ let pad_right = 1 ;
1173+
1174+ let input_data: Vec < f32 > = ( 0 ..batch * in_h * in_w * c_in)
1175+ . map ( |i| ( i as f32 % 17.0 ) * 0.05 - 0.4 )
1176+ . collect ( ) ;
1177+ let kernel_data: Vec < f32 > = ( 0 ..3 * 3 * c_in * c_out)
1178+ . map ( |i| ( i as f32 % 11.0 ) * 0.03 - 0.15 )
1179+ . collect ( ) ;
1180+
1181+ let input = Tensor :: from_vec ( vec ! [ batch, in_h, in_w, c_in] , input_data) . unwrap ( ) ;
1182+ let kernel = Tensor :: from_vec ( vec ! [ 3 , 3 , c_in, c_out] , kernel_data) . unwrap ( ) ;
1183+ let bias = with_bias. then ( || {
1184+ let bias_data: Vec < f32 > = ( 0 ..c_out) . map ( |i| i as f32 * 0.07 - 0.13 ) . collect ( ) ;
1185+ Tensor :: from_vec ( vec ! [ c_out] , bias_data) . unwrap ( )
1186+ } ) ;
1187+
1188+ let winograd = winograd_conv2d_nhwc (
1189+ input. data ( ) ,
1190+ kernel. data ( ) ,
1191+ bias. as_ref ( ) . map ( Tensor :: data) ,
1192+ batch,
1193+ in_h,
1194+ in_w,
1195+ c_in,
1196+ c_out,
1197+ pad_top,
1198+ pad_left,
1199+ pad_bottom,
1200+ pad_right,
1201+ activation,
1202+ )
1203+ . unwrap ( ) ;
1204+ let indirect = conv2d_nhwc_indirect_padded (
1205+ & input,
1206+ & kernel,
1207+ bias. as_ref ( ) ,
1208+ 1 ,
1209+ 1 ,
1210+ pad_top,
1211+ pad_left,
1212+ pad_bottom,
1213+ pad_right,
1214+ activation,
1215+ )
1216+ . unwrap ( ) ;
1217+
1218+ assert_eq ! ( winograd. shape( ) , indirect. shape( ) ) ;
1219+ for ( idx, ( & actual, & expected) ) in winograd. data ( ) . iter ( ) . zip ( indirect. data ( ) ) . enumerate ( ) {
1220+ let diff = ( actual - expected) . abs ( ) ;
1221+ assert ! (
1222+ diff <= 1.0e-4 ,
1223+ "with_bias={with_bias} activation={activation:?} idx={idx} actual={actual} expected={expected} diff={diff}"
1224+ ) ;
1225+ }
1226+ }
1227+
/// Winograd and the indirect padded path must agree with no bias and
/// `Activation::None`.
#[test]
fn winograd_conv2d_matches_indirect_padded() {
    let with_bias = false;
    let activation = Activation::None;
    assert_winograd_matches_indirect_padded(with_bias, activation);
}
1232+
/// Winograd and the indirect padded path must agree when a bias is supplied
/// and the activation is `Activation::None`.
#[test]
fn winograd_conv2d_matches_indirect_padded_bias_only() {
    let with_bias = true;
    let activation = Activation::None;
    assert_winograd_matches_indirect_padded(with_bias, activation);
}
1237+
/// Winograd and the indirect padded path must agree without a bias for
/// `Activation::Relu` and then `Activation::Silu`.
#[test]
fn winograd_conv2d_matches_indirect_padded_activation_only() {
    let with_bias = false;
    // Checked in the same order as before; the first mismatch panics and stops.
    assert_winograd_matches_indirect_padded(with_bias, Activation::Relu);
    assert_winograd_matches_indirect_padded(with_bias, Activation::Silu);
}
1244+
/// Winograd and the indirect padded path must agree when a bias is supplied
/// together with `Activation::Relu` and then `Activation::Silu`.
#[test]
fn winograd_conv2d_matches_indirect_padded_with_bias_and_activation() {
    let with_bias = true;
    // Checked in the same order as before; the first mismatch panics and stops.
    assert_winograd_matches_indirect_padded(with_bias, Activation::Relu);
    assert_winograd_matches_indirect_padded(with_bias, Activation::Silu);
}
11621251}
0 commit comments