|
1 | 1 | #include <gtest/gtest.h> |
| 2 | +#include <vector> |
| 3 | + |
2 | 4 | #include "04_convolution/conv_implicit_gemm.cuh" |
| 5 | +#include "04_convolution/conv_winograd.cuh" |
3 | 6 | #include "common/tensor.cuh" |
4 | 7 | #include "../test_utils.hpp" |
5 | 8 |
|
6 | | -TEST(ConvolutionTest, BasicConv2d) { |
7 | | - int batch = 1, in_c = 3, out_c = 16; |
8 | | - int in_h = 32, in_w = 32; |
9 | | - int k_h = 3, k_w = 3; |
10 | | - int stride = 1, pad = 1; |
11 | | - |
12 | | - int out_h = (in_h + 2 * pad - k_h) / stride + 1; |
13 | | - int out_w = (in_w + 2 * pad - k_w) / stride + 1; |
14 | | - |
15 | | - auto input = hpc::test::random_vector<float>(batch * in_c * in_h * in_w, -1.0f, 1.0f); |
16 | | - auto weight = hpc::test::random_vector<float>(out_c * in_c * k_h * k_w, -1.0f, 1.0f); |
17 | | - |
18 | | - hpc::Tensor<float> d_input(batch * in_c * in_h * in_w); |
19 | | - hpc::Tensor<float> d_weight(out_c * in_c * k_h * k_w); |
20 | | - hpc::Tensor<float> d_output(batch * out_c * out_h * out_w); |
21 | | - |
| 9 | +namespace { |
| 10 | + |
| 11 | +std::vector<float> cpu_conv2d(const std::vector<float>& input, |
| 12 | + const std::vector<float>& weight, |
| 13 | + const hpc::convolution::ConvParams& p) { |
| 14 | + const int out_h = (p.in_height + 2 * p.pad_h - p.dilation_h * (p.kernel_h - 1) - 1) / p.stride_h + 1; |
| 15 | + const int out_w = (p.in_width + 2 * p.pad_w - p.dilation_w * (p.kernel_w - 1) - 1) / p.stride_w + 1; |
| 16 | + std::vector<float> output(p.batch * p.out_channels * out_h * out_w, 0.0f); |
| 17 | + |
| 18 | + for (int b = 0; b < p.batch; ++b) { |
| 19 | + for (int oc = 0; oc < p.out_channels; ++oc) { |
| 20 | + for (int oh = 0; oh < out_h; ++oh) { |
| 21 | + for (int ow = 0; ow < out_w; ++ow) { |
| 22 | + float sum = 0.0f; |
| 23 | + for (int ic = 0; ic < p.in_channels; ++ic) { |
| 24 | + for (int kh = 0; kh < p.kernel_h; ++kh) { |
| 25 | + for (int kw = 0; kw < p.kernel_w; ++kw) { |
| 26 | + const int ih = oh * p.stride_h - p.pad_h + kh * p.dilation_h; |
| 27 | + const int iw = ow * p.stride_w - p.pad_w + kw * p.dilation_w; |
| 28 | + if (ih < 0 || ih >= p.in_height || iw < 0 || iw >= p.in_width) { |
| 29 | + continue; |
| 30 | + } |
| 31 | + const int input_idx = b * (p.in_channels * p.in_height * p.in_width) + |
| 32 | + ic * (p.in_height * p.in_width) + |
| 33 | + ih * p.in_width + iw; |
| 34 | + const int weight_idx = oc * (p.in_channels * p.kernel_h * p.kernel_w) + |
| 35 | + ic * (p.kernel_h * p.kernel_w) + |
| 36 | + kh * p.kernel_w + kw; |
| 37 | + sum += input[input_idx] * weight[weight_idx]; |
| 38 | + } |
| 39 | + } |
| 40 | + } |
| 41 | + const int output_idx = b * (p.out_channels * out_h * out_w) + |
| 42 | + oc * (out_h * out_w) + oh * out_w + ow; |
| 43 | + output[output_idx] = sum; |
| 44 | + } |
| 45 | + } |
| 46 | + } |
| 47 | + } |
| 48 | + |
| 49 | + return output; |
| 50 | +} |
| 51 | + |
| 52 | +} // namespace |
| 53 | + |
| 54 | +TEST(ConvolutionTest, ImplicitGemmMatchesReference) { |
| 55 | + const hpc::convolution::ConvParams params{ |
| 56 | + 1, 2, 3, 5, 5, |
| 57 | + 3, 3, 1, 1, 1, 1, 1, 1, |
| 58 | + }; |
| 59 | + const int out_h = (params.in_height + 2 * params.pad_h - params.dilation_h * (params.kernel_h - 1) - 1) / |
| 60 | + params.stride_h + 1; |
| 61 | + const int out_w = (params.in_width + 2 * params.pad_w - params.dilation_w * (params.kernel_w - 1) - 1) / |
| 62 | + params.stride_w + 1; |
| 63 | + |
| 64 | + const auto input = hpc::test::random_vector<float>( |
| 65 | + params.batch * params.in_channels * params.in_height * params.in_width, -1.0f, 1.0f); |
| 66 | + const auto weight = hpc::test::random_vector<float>( |
| 67 | + params.out_channels * params.in_channels * params.kernel_h * params.kernel_w, -1.0f, 1.0f); |
| 68 | + const auto expected = cpu_conv2d(input, weight, params); |
| 69 | + |
| 70 | + hpc::Tensor<float> d_input(input.size()); |
| 71 | + hpc::Tensor<float> d_weight(weight.size()); |
| 72 | + hpc::Tensor<float> d_output(expected.size()); |
| 73 | + |
22 | 74 | d_input.copy_from_host(input); |
23 | 75 | d_weight.copy_from_host(weight); |
24 | | - |
25 | | - hpc::convolution::ConvParams params{ |
26 | | - batch, in_c, out_c, in_h, in_w, |
27 | | - k_h, k_w, stride, stride, pad, pad, 1, 1 |
28 | | - }; |
29 | | - |
| 76 | + d_output.zero(); |
| 77 | + |
30 | 78 | hpc::convolution::conv2d_implicit_gemm<float>( |
31 | 79 | d_input.data(), d_weight.data(), d_output.data(), params); |
32 | 80 | cudaDeviceSynchronize(); |
33 | | - |
34 | | - auto output = d_output.to_host(); |
35 | | - EXPECT_EQ(output.size(), batch * out_c * out_h * out_w); |
| 81 | + |
| 82 | + const auto output = d_output.to_host(); |
| 83 | + ASSERT_EQ(output.size(), static_cast<size_t>(params.batch * params.out_channels * out_h * out_w)); |
| 84 | + EXPECT_TRUE(hpc::test::vectors_almost_equal(output, expected, 1e-4f, 1e-4f)); |
| 85 | +} |
| 86 | + |
| 87 | +TEST(ConvolutionTest, WinogradPathMatchesImplicitGemmFallback) { |
| 88 | + constexpr int batch = 1; |
| 89 | + constexpr int in_channels = 2; |
| 90 | + constexpr int out_channels = 2; |
| 91 | + constexpr int height = 6; |
| 92 | + constexpr int width = 6; |
| 93 | + constexpr int kernel = 3; |
| 94 | + constexpr int output_size = batch * out_channels * height * width; |
| 95 | + |
| 96 | + const auto input = hpc::test::random_vector<float>(batch * in_channels * height * width, -1.0f, 1.0f); |
| 97 | + const auto weight = hpc::test::random_vector<float>(out_channels * in_channels * kernel * kernel, -1.0f, 1.0f); |
| 98 | + |
| 99 | + hpc::Tensor<float> d_input(input.size()); |
| 100 | + hpc::Tensor<float> d_weight(weight.size()); |
| 101 | + hpc::Tensor<float> d_implicit(output_size); |
| 102 | + hpc::Tensor<float> d_winograd(output_size); |
| 103 | + |
| 104 | + d_input.copy_from_host(input); |
| 105 | + d_weight.copy_from_host(weight); |
| 106 | + d_implicit.zero(); |
| 107 | + d_winograd.zero(); |
| 108 | + |
| 109 | + const hpc::convolution::ConvParams params{ |
| 110 | + batch, in_channels, out_channels, height, width, |
| 111 | + kernel, kernel, 1, 1, 1, 1, 1, 1, |
| 112 | + }; |
| 113 | + |
| 114 | + hpc::convolution::conv2d_implicit_gemm<float>( |
| 115 | + d_input.data(), d_weight.data(), d_implicit.data(), params); |
| 116 | + hpc::convolution::conv2d_winograd<float>( |
| 117 | + d_input.data(), d_weight.data(), d_winograd.data(), batch, in_channels, out_channels, height, width); |
| 118 | + cudaDeviceSynchronize(); |
| 119 | + |
| 120 | + const auto implicit_output = d_implicit.to_host(); |
| 121 | + const auto winograd_output = d_winograd.to_host(); |
| 122 | + EXPECT_TRUE(hpc::test::vectors_almost_equal(winograd_output, implicit_output, 1e-5f, 1e-5f)); |
36 | 123 | } |
0 commit comments