@@ -72,9 +72,28 @@ struct SliceBlockDesc {
   uint64_t size;
 };

-template <typename OutputType>
-constexpr int coalesced_values = sizeof(OutputType) >= 4 ? 1 : 4 / sizeof(OutputType);
-
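+// Aliases a small array of T with one 32-bit word so that narrow output
+// types can be flushed to memory with a single packed store.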
+template <typename T>
+union PackedBuffer {
+  using PackedType = uint32_t;
+  static constexpr size_t kCapacity =
+      sizeof(T) >= sizeof(PackedType) ? 1 : sizeof(PackedType) / sizeof(T);
+
+  T values[kCapacity];
+  PackedType raw;
+
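+  // Writes `count` values to `mem`: one packed 32-bit store when a full,
+  // suitably aligned packet is available, element-wise stores otherwise.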
+  __device__ inline void store(T *mem, size_t count) {
+    if (kCapacity == 1) {
+      *mem = *values;
+    } else if (count == kCapacity && reinterpret_cast<uintptr_t>(mem) % sizeof(PackedType) == 0) {
+      *reinterpret_cast<PackedType *>(mem) = raw;
+    } else {
+      #pragma unroll
+      for (size_t i = 0; i < count; i++) {
+        mem[i] = values[i];
+      }
+    }
+  }
+};

 /**
  * @brief Simplified algorithm when no padding is necessary
@@ -91,12 +110,14 @@ __device__ void SliceFuncNoPad(OutputType *__restrict__ out, const InputType *__
     return;
   }

-  for (; offset < block_end; offset += blockDim.x * coalesced_values<OutputType>) {
+  for (; offset < block_end; offset += blockDim.x * PackedBuffer<OutputType>::kCapacity) {
+    PackedBuffer<OutputType> result;
+
+    uint64_t i;
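+    // i outlives the loop so the tail block's partial count reaches store().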
     #pragma unroll
-    for (uint64_t i = 0; i < coalesced_values<OutputType>; i++) {
+    for (i = 0; i < PackedBuffer<OutputType>::kCapacity; i++) {
       uint64_t idx = offset + i;
       if (idx >= block_end) break;
-      uint64_t out_idx = idx;
       uint64_t in_idx = 0;

       #pragma unroll
@@ -105,8 +126,9 @@ __device__ void SliceFuncNoPad(OutputType *__restrict__ out, const InputType *__
         in_idx += i_d * in_strides[d];
       }
       in_idx += idx;  // remaining dims have equal strides
-      out[out_idx] = clamp<OutputType>(in[in_idx]);
+      result.values[i] = clamp<OutputType>(in[in_idx]);
     }
+    result.store(&out[offset], i);
   }
 }

@@ -139,14 +161,16 @@ __device__ void SliceFunc(OutputType *__restrict__ out, const InputType *__restr
     inner_in_extent = Dims > 1 ? in_strides[LastDim - 1] : in_shape[LastDim] * in_strides[LastDim];
   }

-  for (; offset < block_end; offset += blockDim.x * coalesced_values<OutputType>) {
+  for (; offset < block_end; offset += blockDim.x * PackedBuffer<OutputType>::kCapacity) {
+    PackedBuffer<OutputType> result;
+
+    uint64_t i;
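+    // As above, i carries the number of values actually packed this step.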
 #ifndef __clang__
     #pragma unroll
 #endif
-    for (uint64_t i = 0; i < coalesced_values<OutputType>; i++) {
+    for (i = 0; i < PackedBuffer<OutputType>::kCapacity; i++) {
       uint64_t idx = offset + i;
       if (idx >= block_end) break;
-      uint64_t out_idx = idx;

       // If no dimensions were skipped (AllDims=true) we can avoid division in the last dimension,
       // because we know the strides are 1 (or we treat them as 1 if we fused dimensions)
@@ -175,15 +199,16 @@ __device__ void SliceFunc(OutputType *__restrict__ out, const InputType *__restr
       OutputType value = __ldg(&fill_values[i_c]);
       if (!out_of_bounds)
         value = clamp<OutputType>(in[in_idx]);
-      out[out_idx] = value;
+      result.values[i] = value;
     }
+    result.store(&out[offset], i);
   }
 }

 template <typename OutputType, typename InputType, int Dims, bool SupportPad>
 __global__ void SliceKernel(const SliceSampleDesc<Dims> *samples, const SliceBlockDesc *blocks) {
   int sampleIdx = blocks[blockIdx.x].sampleIdx;
-  uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x * coalesced_values<OutputType>;
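+  // Each thread now covers kCapacity consecutive output elements per step.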
+  uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x * PackedBuffer<OutputType>::kCapacity;
   uint64_t block_end = blocks[blockIdx.x].offset + blocks[blockIdx.x].size;
   auto sample = samples[sampleIdx];
   auto *out = static_cast<OutputType*>(sample.out);
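
For illustration, a minimal device-side sketch (not part of this commit) of how
the packed path behaves for a narrow output type; `dst` is a hypothetical
4-byte-aligned uint8_t pointer:

  // With T = uint8_t, kCapacity == 4: four bytes alias one uint32_t.
  PackedBuffer<uint8_t> buf;
  for (size_t i = 0; i < PackedBuffer<uint8_t>::kCapacity; i++)
    buf.values[i] = static_cast<uint8_t>(i);
  buf.store(dst, 4);  // full, aligned packet -> single 32-bit store of buf.raw
  buf.store(dst, 3);  // partial tail -> falls back to three element-wise stores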