Skip to content

Commit 8a948ee

Browse files
authored
Explicitly coalesce stores in Slice for smaller output types (#3600)
Signed-off-by: Szymon Karpiński <hugo@staszic.waw.pl>
1 parent e1f563b commit 8a948ee

1 file changed

Lines changed: 37 additions & 12 deletions

File tree

dali/kernels/slice/slice_gpu.cuh

Lines changed: 37 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -72,9 +72,28 @@ struct SliceBlockDesc {
7272
uint64_t size;
7373
};
7474

75-
template<typename OutputType>
76-
constexpr int coalesced_values = sizeof(OutputType) >= 4 ? 1 : 4 / sizeof(OutputType);
77-
75+
/**
 * @brief Staging buffer that overlays up to kCapacity values of T with one PackedType word.
 *
 * Values are written one by one into `values`; `store` then writes them to the destination,
 * using a single PackedType-sized store when the buffer is full and the destination is
 * suitably aligned, and element-wise stores otherwise.
 *
 * NOTE(review): reading `raw` after writing `values` relies on union type punning, which
 * standard C++ does not guarantee - confirm this is acceptable for the supported compilers.
 */
template<typename T>
76+
union PackedBuffer {
77+
// Word type used for the combined store.
using PackedType = uint32_t;
78+
// Number of T values that fit in one PackedType; 1 when T is at least as large as PackedType.
static constexpr size_t kCapacity = sizeof(T) >= sizeof(PackedType) ?
79+
1 : sizeof(PackedType) / sizeof(T);
80+
81+
// Element-wise view of the buffer.
T values[kCapacity];
82+
// The same storage viewed as a single PackedType word.
PackedType raw;
83+
84+
/**
 * @brief Writes the buffered values to `mem`.
 *
 * @param mem   destination of the first buffered value
 * @param count number of valid entries in `values`; not consulted when kCapacity == 1
 */
__device__ inline void store(T* mem, size_t count) {
85+
// Only one value fits in the buffer: plain single-element store.
if (kCapacity == 1) {
86+
*mem = *values;
87+
// Buffer is full and `mem` is aligned to PackedType: write everything with one word-sized store.
} else if (count == kCapacity && reinterpret_cast<uintptr_t>(mem) % sizeof(PackedType) == 0) {
88+
*reinterpret_cast<PackedType*>(mem) = raw;
89+
// Partially filled buffer or unaligned destination: fall back to element-wise stores.
} else {
90+
#pragma unroll
91+
for (size_t i = 0; i < count; i++) {
92+
mem[i] = values[i];
93+
}
94+
}
95+
}
96+
};
7897

7998
/**
8099
* @brief Simplified algorithm when no padding is necessary
@@ -91,12 +110,14 @@ __device__ void SliceFuncNoPad(OutputType *__restrict__ out, const InputType *__
91110
return;
92111
}
93112

94-
for (; offset < block_end; offset += blockDim.x * coalesced_values<OutputType>) {
113+
for (; offset < block_end; offset += blockDim.x * PackedBuffer<OutputType>::kCapacity) {
114+
PackedBuffer<OutputType> result;
115+
116+
uint64_t i;
95117
#pragma unroll
96-
for (uint64_t i = 0; i < coalesced_values<OutputType>; i++) {
118+
for (i = 0; i < PackedBuffer<OutputType>::kCapacity; i++) {
97119
uint64_t idx = offset + i;
98120
if (idx >= block_end) break;
99-
uint64_t out_idx = idx;
100121
uint64_t in_idx = 0;
101122

102123
#pragma unroll
@@ -105,8 +126,9 @@ __device__ void SliceFuncNoPad(OutputType *__restrict__ out, const InputType *__
105126
in_idx += i_d * in_strides[d];
106127
}
107128
in_idx += idx; // remaining dims have equal strides
108-
out[out_idx] = clamp<OutputType>(in[in_idx]);
129+
result.values[i] = clamp<OutputType>(in[in_idx]);
109130
}
131+
result.store(&out[offset], i);
110132
}
111133
}
112134

@@ -139,14 +161,16 @@ __device__ void SliceFunc(OutputType *__restrict__ out, const InputType *__restr
139161
inner_in_extent = Dims > 1 ? in_strides[LastDim - 1] : in_shape[LastDim] * in_strides[LastDim];
140162
}
141163

142-
for (; offset < block_end; offset += blockDim.x * coalesced_values<OutputType>) {
164+
for (; offset < block_end; offset += blockDim.x * PackedBuffer<OutputType>::kCapacity) {
165+
PackedBuffer<OutputType> result;
166+
167+
uint64_t i;
143168
#ifndef __clang__
144169
#pragma unroll
145170
#endif
146-
for (uint64_t i = 0; i < coalesced_values<OutputType>; i++) {
171+
for (i = 0; i < PackedBuffer<OutputType>::kCapacity; i++) {
147172
uint64_t idx = offset + i;
148173
if (idx >= block_end) break;
149-
uint64_t out_idx = idx;
150174

151175
// If no dimensions were skipped (AllDims=true) we can avoid division in the last dimension,
152176
because we know the strides are 1 (or we treat them as 1 if we fused dimensions)
@@ -175,15 +199,16 @@ __device__ void SliceFunc(OutputType *__restrict__ out, const InputType *__restr
175199
OutputType value = __ldg(&fill_values[i_c]);
176200
if (!out_of_bounds)
177201
value = clamp<OutputType>(in[in_idx]);
178-
out[out_idx] = value;
202+
result.values[i] = value;
179203
}
204+
result.store(&out[offset], i);
180205
}
181206
}
182207

183208
template <typename OutputType, typename InputType, int Dims, bool SupportPad>
184209
__global__ void SliceKernel(const SliceSampleDesc<Dims> *samples, const SliceBlockDesc *blocks) {
185210
int sampleIdx = blocks[blockIdx.x].sampleIdx;
186-
uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x * coalesced_values<OutputType>;
211+
uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x * PackedBuffer<OutputType>::kCapacity;
187212
uint64_t block_end = blocks[blockIdx.x].offset + blocks[blockIdx.x].size;
188213
auto sample = samples[sampleIdx];
189214
auto *out = static_cast<OutputType*>(sample.out);

0 commit comments

Comments
 (0)