|
9 | 9 | #pragma once |
10 | 10 |
|
11 | 11 | #include <ATen/ATen.h> |
| 12 | +#include <algorithm> |
12 | 13 | #include <cstdint> |
| 14 | +#include <type_traits> |
13 | 15 |
|
14 | 16 | /// Determine an appropriate CUDA block count along the x axis |
15 | 17 | /// |
16 | 18 | /// When launching CUDA kernels the number of blocks B is often calculated |
17 | 19 | /// w.r.t. the number of threads T and items to be processed N as |
18 | 20 | /// B=(N+T-1)/T - which is integer division rounding up. |
19 | 21 | /// This function abstracts that calculation, performs it in an |
20 | | -/// overflow-safe manner, and limits the return value appropriately. |
| 22 | +/// overflow-safe manner, and limits the return value to the CUDA grid-x |
| 23 | +/// dimension cap (2^31-1 for compute capability >= 3.5). |
21 | 24 | /// |
22 | | -/// This is a general function for all integral data types. |
23 | | -/// The goal of this set of functions is to ensure correct calculations |
24 | | -/// across a variety of data types without forcing the programmer to |
25 | | -/// cast to an appropriate type (which is dangerous because we don't |
26 | | -/// have conversion warnings enabled). The values of the variables |
27 | | -/// can then be checked for correctness at run-time. |
28 | | -/// Specialized functions below handle various combinations of signed |
29 | | -/// and unsigned inputs. This system prevents "pointless comparison |
30 | | -/// against zero" warnings from the compiler for unsigned types |
31 | | -/// (simpler ways of suppressing this warning didn't work) while |
32 | | -/// maintaining the various warnings. |
33 | | -/// |
34 | | -/// Function is designed to facilitate run-time value checking. |
35 | | -template < |
36 | | - typename Integer1, |
37 | | - typename Integer2, |
38 | | - std::enable_if_t<std::is_integral_v<Integer1>, bool> = true, |
39 | | - std::enable_if_t<std::is_integral_v<Integer2>, bool> = true> |
40 | | -constexpr uint32_t cuda_calc_xblock_count_base( |
| 25 | +/// Accepts any pair of integral types. The `if constexpr` branches on |
| 26 | +/// signedness emit the `>= 0` TORCH_CHECKs only for signed types, which |
| 27 | +/// avoids "pointless comparison against zero" warnings on unsigned types |
| 28 | +/// without needing per-signedness SFINAE overloads. |
| 29 | +template <typename Integer1, typename Integer2> |
| 30 | +constexpr uint32_t cuda_calc_xblock_count( |
41 | 31 | Integer1 num_items, |
42 | 32 | Integer2 threads_per_block) { |
43 | | - // The number of threads can be as high as 2048 on some newer architectures, |
44 | | - // but this is not portable. |
| 33 | + static_assert( |
| 34 | + std::is_integral_v<Integer1>, |
| 35 | + "cuda_calc_xblock_count: num_items must be an integral type"); |
| 36 | + static_assert( |
| 37 | + std::is_integral_v<Integer2>, |
| 38 | + "cuda_calc_xblock_count: threads_per_block must be an integral type"); |
| 39 | + |
| 40 | + // The number of threads can be as high as 2048 on some newer |
| 41 | + // architectures, but this is not portable. |
45 | 42 | TORCH_CHECK(threads_per_block <= 1024, "Number of threads must be <=1024!"); |
| 43 | + |
| 44 | + if constexpr (std::is_signed_v<Integer1>) { |
| 45 | + TORCH_CHECK( |
| 46 | + num_items >= 0, |
| 47 | + "When calculating block counts, the number of items must be positive!"); |
| 48 | + } |
| 49 | + if constexpr (std::is_signed_v<Integer2>) { |
| 50 | + TORCH_CHECK( |
| 51 | + threads_per_block >= 0, |
| 52 | + "When calculating thread counts, the number of threads must be positive!"); |
| 53 | + } |
| 54 | + |
46 | 55 | // The CUDA specification at |
47 | 56 | // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications |
48 | 57 | // states that for compute capability 3.5-* the grid dimension of a kernel |
49 | | - // launch must must be <=2^31-1. |
| 58 | + // launch must be <=2^31-1. |
50 | 59 | constexpr uint64_t max_blocks = 2147483647; |
51 | 60 | const auto u_num_items = static_cast<uint64_t>(num_items); |
52 | 61 | const auto u_threads = static_cast<uint64_t>(threads_per_block); |
53 | | - // Overflow safe variant of (a + b - 1) / b |
| 62 | + // Overflow-safe variant of (a + b - 1) / b. |
54 | 63 | const uint64_t blocks = |
55 | 64 | u_num_items / u_threads + (u_num_items % u_threads != 0); |
56 | 65 | return static_cast<uint32_t>(std::min(blocks, max_blocks)); |
57 | 66 | } |
58 | 67 |
|
59 | | -// See: cuda_calc_xblock_count_base |
60 | | -template < |
61 | | - typename Integer1, |
62 | | - typename Integer2, |
63 | | - std::enable_if_t< |
64 | | - std::is_integral_v<Integer1> && std::is_signed_v<Integer2>, |
65 | | - bool> = true, |
66 | | - std::enable_if_t< |
67 | | - std::is_integral_v<Integer2> && std::is_unsigned_v<Integer2>, |
68 | | - bool> = true> |
69 | | -constexpr uint32_t cuda_calc_xblock_count( |
70 | | - Integer1 num_items, |
71 | | - Integer2 threads_per_block) { |
72 | | - TORCH_CHECK( |
73 | | - num_items >= 0, |
74 | | - "When calculating block counts, the number of items must be positive!"); |
75 | | - return cuda_calc_xblock_count_base(num_items, threads_per_block); |
76 | | -} |
77 | | - |
78 | | -// See: cuda_calc_xblock_count_base |
79 | | -template < |
80 | | - typename Integer1, |
81 | | - typename Integer2, |
82 | | - std::enable_if_t< |
83 | | - std::is_integral_v<Integer1> && std::is_unsigned_v<Integer2>, |
84 | | - bool> = true, |
85 | | - std::enable_if_t< |
86 | | - std::is_integral_v<Integer2> && std::is_signed_v<Integer2>, |
87 | | - bool> = true> |
88 | | -constexpr uint32_t cuda_calc_xblock_count( |
89 | | - Integer1 num_items, |
90 | | - Integer2 threads_per_block) { |
91 | | - TORCH_CHECK( |
92 | | - threads_per_block >= 0, |
93 | | - "When calculating thread counts, the number of threads must be positive!"); |
94 | | - return cuda_calc_xblock_count_base(num_items, threads_per_block); |
95 | | -} |
96 | | - |
97 | | -// See: cuda_calc_xblock_count_base |
98 | | -template < |
99 | | - typename Integer1, |
100 | | - typename Integer2, |
101 | | - std::enable_if_t< |
102 | | - std::is_integral_v<Integer1> && std::is_signed_v<Integer2>, |
103 | | - bool> = true, |
104 | | - std::enable_if_t< |
105 | | - std::is_integral_v<Integer2> && std::is_signed_v<Integer2>, |
106 | | - bool> = true> |
107 | | -constexpr uint32_t cuda_calc_xblock_count( |
108 | | - Integer1 num_items, |
109 | | - Integer2 threads_per_block) { |
110 | | - TORCH_CHECK( |
111 | | - num_items >= 0, |
112 | | - "When calculating block counts, the number of items must be positive!"); |
113 | | - TORCH_CHECK( |
114 | | - threads_per_block >= 0, |
115 | | - "When calculating thread counts, the number of threads must be positive!"); |
116 | | - return cuda_calc_xblock_count_base(num_items, threads_per_block); |
117 | | -} |
118 | | - |
119 | | -// See: cuda_calc_xblock_count_base |
120 | | -template < |
121 | | - typename Integer1, |
122 | | - typename Integer2, |
123 | | - std::enable_if_t< |
124 | | - std::is_integral_v<Integer1> && std::is_unsigned_v<Integer2>, |
125 | | - bool> = true, |
126 | | - std::enable_if_t< |
127 | | - std::is_integral_v<Integer2> && std::is_unsigned_v<Integer2>, |
128 | | - bool> = true> |
129 | | -constexpr uint32_t cuda_calc_xblock_count( |
130 | | - Integer1 num_items, |
131 | | - Integer2 threads_per_block) { |
132 | | - return cuda_calc_xblock_count_base(num_items, threads_per_block); |
133 | | -} |
134 | | - |
135 | | -/// Determine an appropriate CUDA block count. |
| 68 | +/// Determine an appropriate CUDA block count for a y- or z-dim of the |
| 69 | +/// launch grid. |
136 | 70 | /// |
137 | | -/// See cuda_calc_xblock_count_base() for details. |
138 | | -template < |
139 | | - typename Integer1, |
140 | | - typename Integer2, |
141 | | - std::enable_if_t<std::is_integral_v<Integer1>, bool> = true, |
142 | | - std::enable_if_t<std::is_integral_v<Integer2>, bool> = true> |
| 71 | +/// The CUDA specification states that the grid dimension of a kernel |
| 72 | +/// launch must generally be <=65535. (For compute capability 3.5-* the |
| 73 | +/// grid's x-dimension may be <=2^31-1; that larger limit is enforced |
| 74 | +/// by `cuda_calc_xblock_count` instead.) Because this function does not |
| 75 | +/// know which dimension is being calculated, it uses the smaller limit. |
| 76 | +/// |
| 77 | +/// See `cuda_calc_xblock_count` for the underlying arithmetic. |
| 78 | +template <typename Integer1, typename Integer2> |
143 | 79 | constexpr uint32_t cuda_calc_block_count( |
144 | 80 | Integer1 num_items, |
145 | 81 | Integer2 threads_per_block) { |
146 | | - // The CUDA specification at |
147 | | - // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications |
148 | | - // states that the grid dimension of a kernel launch must generally |
149 | | - // be <=65535. (For compute capability 3.5-* the grid's x-dimension must |
150 | | - // be <=2^31-1.) Because this function does not know which dimension |
151 | | - // is being calculated, we use the smaller limit. |
152 | 82 | constexpr uint32_t max_blocks = 65535; |
153 | 83 | return std::min( |
154 | 84 | cuda_calc_xblock_count(num_items, threads_per_block), max_blocks); |
|
0 commit comments