Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ template <typename T, unsigned TPB>
__global__ void EmbedLayerNormKernel(
int hidden_size, const int* input_ids, const int* segment_ids, const T* beta, const T* gamma,
const T* word_embedding, const T* position_embedding, const T* segment_embedding,
const T epsilon, T* output, T* embedding_sum, const int* position_ids, const bool broadcast_position_ids) {
float epsilon, T* output, T* embedding_sum, const int* position_ids, const bool broadcast_position_ids) {
KeyValuePairSum pair_sum;
// 1. lookup word and segment of the block
// blockIdx.x = position in the sequence
Expand All @@ -134,7 +134,7 @@ __global__ void EmbedLayerNormKernel(
__shared__ int segment_id;
__shared__ int position_id;

const T rld = T(1.f / hidden_size);
const float rld = 1.f / hidden_size;
const int sequence_position = blockIdx.y * gridDim.x + blockIdx.x;
if (threadIdx.x == 0) {
word_id = input_ids[sequence_position];
Expand Down Expand Up @@ -162,7 +162,7 @@ __global__ void EmbedLayerNormKernel(
// the output offset is given by b * (sequence_length * hidden_size) + s * hidden_size
const int output_offset = sequence_position * hidden_size;

cub::KeyValuePair<T, T> thread_data(0, 0);
cub::KeyValuePair<float, float> thread_data(0.f, 0.f);

for (int it = threadIdx.x; it < hidden_size; it += TPB) {
const T w(word_embedding[word_offset + it]);
Expand All @@ -177,8 +177,9 @@ __global__ void EmbedLayerNormKernel(
embedding_sum[output_offset + it] = val;
}

const T rldval = rld * val;
thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));
const float val_f = static_cast<float>(val);
const float rldval = rld * val_f;
thread_data = pair_sum(thread_data, cub::KeyValuePair<float, float>(rldval, rldval * val_f));
}

// 3. layer norm on the sum
Expand All @@ -190,7 +191,7 @@ Status EmbedSkipLayerNorm(
cudaStream_t stream, int hidden_size, int batch_size, int sequence_length,
const int* input_ids, const int* segment_ids, const T* beta, const T* gamma,
const T* word_embedding, const T* position_embedding, const T* segment_embedding,
const T epsilon, T* output, T* embedding_sum, const int* position_ids,
float epsilon, T* output, T* embedding_sum, const int* position_ids,
const bool broadcast_position_ids) {
constexpr int tpb = 256;
const dim3 grid(sequence_length, batch_size, 1);
Expand Down Expand Up @@ -238,7 +239,7 @@ Status LaunchEmbedLayerNormKernel(
stream, hidden_size, batch_size, sequence_length, input_ids, segment_ids,
reinterpret_cast<const half*>(beta), reinterpret_cast<const half*>(gamma),
reinterpret_cast<const half*>(word_embedding), reinterpret_cast<const half*>(position_embedding),
reinterpret_cast<const half*>(segment_embedding), __float2half_rn(epsilon),
reinterpret_cast<const half*>(segment_embedding), epsilon,
reinterpret_cast<half*>(output), reinterpret_cast<half*>(embedding_sum), position_ids,
broadcast_position_ids);
} else {
Expand Down
101 changes: 44 additions & 57 deletions onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -74,97 +74,80 @@ struct KeyValuePairSum {
const cub::KeyValuePair<float, float>& b) {
return cub::KeyValuePair<float, float>(a.key + b.key, a.value + b.value);
}

__device__ inline cub::KeyValuePair<half, half> operator()(const cub::KeyValuePair<half, half>& a,
const cub::KeyValuePair<half, half>& b) {
const half2 a2 = __halves2half2(a.key, a.value);
const half2 b2 = __halves2half2(b.key, b.value);
const half2 res = AddHalf2(a2, b2);
return cub::KeyValuePair<half, half>(__low2half(res), __high2half(res));
}

__device__ inline cub::KeyValuePair<half2, half2> operator()(const cub::KeyValuePair<half2, half2>& a,
const cub::KeyValuePair<half2, half2>& b) {
return cub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
}

__device__ inline cub::KeyValuePair<nv_bfloat16, nv_bfloat16> operator()(const cub::KeyValuePair<nv_bfloat16, nv_bfloat16>& a,
const cub::KeyValuePair<nv_bfloat16, nv_bfloat16>& b) {
const nv_bfloat162 a2 = __halves2bfloat162(a.key, a.value);
const nv_bfloat162 b2 = __halves2bfloat162(b.key, b.value);
const nv_bfloat162 res = AddHalf2(a2, b2);
return cub::KeyValuePair<nv_bfloat16, nv_bfloat16>(__low2bfloat16(res), __high2bfloat16(res));
}
};

template <typename T, int TPB>
__device__ inline void LayerNorm(
const cub::KeyValuePair<T, T>& thread_data, const int ld, const int offset, const T* beta,
const T* gamma, const T epsilon, T* output) {
const cub::KeyValuePair<float, float>& thread_data, const int ld, const int offset, const T* beta,
const T* gamma, const float epsilon, T* output) {
// Assuming thread_data is already divided by ld
// Uses fp32 accumulation for mean/variance to avoid overflow in fp16/bf16.

using BlockReduce = cub::BlockReduce<cub::KeyValuePair<T, T>, TPB>;
using BlockReduce = cub::BlockReduce<cub::KeyValuePair<float, float>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
__shared__ float mu; // mean
__shared__ float rsigma; // 1 / std.dev.

KeyValuePairSum pair_sum;
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);

if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
rsigma = rsqrtf(sum_kv.value - mu * mu + epsilon);
}
__syncthreads();

for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(gamma[i]);
const T b = (nullptr == beta) ? (T)0 : beta[i];
output[idx] = g * (val - mu) * rsigma + b;
const float val = static_cast<float>(output[idx]);
const float g = static_cast<float>(gamma[i]);
const float b = (nullptr == beta) ? 0.f : static_cast<float>(beta[i]);
output[idx] = static_cast<T>(g * (val - mu) * rsigma + b);
}
}

template <typename T, int TPB>
__device__ inline void SimplifiedLayerNorm(
const T& thread_data, const int ld, const int offset, const T* gamma, const T epsilon, T* output) {
const float& thread_data, const int ld, const int offset, const T* gamma, const float epsilon, T* output) {
// Assuming thread_data is already divided by ld
// Uses fp32 accumulation to avoid overflow in fp16/bf16.

using BlockReduce = cub::BlockReduce<T, TPB>;
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T rsigma; // 1 / std.dev.
__shared__ float rsigma; // 1 / std.dev.

const T sum = BlockReduce(temp_storage).Sum(thread_data);
const float sum = BlockReduce(temp_storage).Sum(thread_data);

if (threadIdx.x == 0) {
rsigma = Rsqrt(sum + epsilon);
rsigma = rsqrtf(sum + epsilon);
}
__syncthreads();

for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(gamma[i]);
output[idx] = g * val * rsigma;
const float val = static_cast<float>(output[idx]);
const float g = static_cast<float>(gamma[i]);
output[idx] = static_cast<T>(g * val * rsigma);
}
}

template <typename T, int TPB, int ILP>
__device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair<T, T>& thread_data,
__device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair<float, float>& thread_data,
const int ld, const int idx, const T* beta, const T* gamma,
const T epsilon, T* output) {
const float epsilon, T* output) {
// Assuming thread_data is already divided by ld
// Small settings: the block covers the leading dimension TPB >= ld. The input
// value is available in a register
// Uses fp32 accumulation for mean/variance to avoid overflow in fp16/bf16.
using VecT = aligned_vector<T, ILP>;
using BlockReduce = cub::BlockReduce<cub::KeyValuePair<T, T>, TPB>;
using BlockReduce = cub::BlockReduce<cub::KeyValuePair<float, float>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
T beta_v[ILP], gamma_v[ILP], output_v[ILP];
__shared__ float mu; // mean
__shared__ float rsigma; // 1 / std.dev.
T gamma_v[ILP], output_v[ILP];

const bool is_valid = ILP * threadIdx.x < ld;
T beta_v[ILP];
if (is_valid) {
if (beta != nullptr) {
VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
Expand All @@ -176,20 +159,21 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair<
}

KeyValuePairSum pair_sum;
const cub::KeyValuePair<T, T> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
const cub::KeyValuePair<float, float> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);

if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
rsigma = rsqrtf(sum_kv.value - mu * mu + epsilon);
}
__syncthreads();

if (is_valid) {
#pragma unroll
for (int i = 0; i < ILP; i++) {
output_v[i] = (beta != nullptr)
? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i]
: gamma_v[i] * (input_v[i] - mu) * rsigma;
const float in_f = static_cast<float>(input_v[i]);
const float g_f = static_cast<float>(gamma_v[i]);
const float b_f = (beta != nullptr) ? static_cast<float>(beta_v[i]) : 0.f;
output_v[i] = static_cast<T>(g_f * (in_f - mu) * rsigma + b_f);
}

VecT* output_val = reinterpret_cast<VecT*>(&output_v);
Expand All @@ -198,15 +182,16 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair<
}

template <typename T, int TPB, int ILP>
__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& thread_data, const int ld, const int idx,
const T* gamma, const T epsilon, T* output) {
__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const float& thread_data, const int ld, const int idx,
const T* gamma, const float epsilon, T* output) {
// Assuming thread_data is already divided by ld
// Small settings: the block covers the leading dimension TPB >= ld. The input
// value is available in a register
// Uses fp32 accumulation to avoid overflow in fp16/bf16.
using VecT = aligned_vector<T, ILP>;
using BlockReduce = cub::BlockReduce<T, TPB>;
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T rsigma; // 1 / std.dev.
__shared__ float rsigma; // 1 / std.dev.

const bool is_valid = ILP * threadIdx.x < ld;

Expand All @@ -217,17 +202,19 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
}

const T sum = BlockReduce(temp_storage).Sum(thread_data);
const float sum = BlockReduce(temp_storage).Sum(thread_data);

if (threadIdx.x == 0) {
rsigma = Rsqrt(sum + epsilon);
rsigma = rsqrtf(sum + epsilon);
}
__syncthreads();

if (is_valid) {
#pragma unroll
for (int i = 0; i < ILP; i++) {
output_v[i] = gamma_v[i] * input_v[i] * rsigma;
const float in_f = static_cast<float>(input_v[i]);
const float g_f = static_cast<float>(gamma_v[i]);
output_v[i] = static_cast<T>(g_f * in_f * rsigma);
}

VecT* output_val = reinterpret_cast<VecT*>(&output_v);
Expand Down
55 changes: 20 additions & 35 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,6 @@ namespace contrib {
namespace cuda {

namespace {
template <typename T>
T maybe2half(float x);

template <>
float maybe2half(float x) {
return x;
}

template <>
half maybe2half(float x) {
return __float2half_rn(x);
}

template <>
nv_bfloat16 maybe2half(float x) {
return __float2bfloat16_rn(x);
}

// Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
// in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time.
constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192};
Expand Down Expand Up @@ -90,15 +72,16 @@ bool CanVectorized(void* output, void* sum_output, const void* input, const void

template <typename T, unsigned TPB, bool Simplified>
__global__ void SkipLayerNormKernel(
T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon,
const int ld, int skip_size) {
const T reverse_ld = T(1.f / ld);
T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta,
float epsilon, const int ld, int skip_size) {
const float reverse_ld = 1.f / ld;
const int offset = blockIdx.x * ld;
const bool has_bias = (bias != nullptr);

// Reduce sum of x and x^2, and the results are divided by ld.
// Uses fp32 accumulation to avoid overflow in fp16/bf16.
KeyValuePairSum pair_sum;
cub::KeyValuePair<T, T> thread_data(0, 0);
cub::KeyValuePair<float, float> thread_data(0.f, 0.f);

Comment on lines 81 to 85
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
Expand All @@ -109,8 +92,9 @@ __global__ void SkipLayerNormKernel(
}
val += skip[idx % skip_size];

const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));
const float val_f = static_cast<float>(val);
const float rldval = reverse_ld * val_f;
thread_data = pair_sum(thread_data, cub::KeyValuePair<float, float>(rldval, rldval * val_f));

if (sum_output != nullptr) {
sum_output[idx] = val;
Expand All @@ -129,15 +113,15 @@ __global__ void SkipLayerNormKernel(
// Vectorized kernel
template <typename T, unsigned TPB, int ILP, bool Simplified>
__global__ void SkipLayerNormKernelSmall(
T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon,
int ld, int skip_size) {
const T rld = T(1.f / ld);
T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta,
float epsilon, int ld, int skip_size) {
const float rld = 1.f / ld;
const int idx = blockIdx.x * ld + threadIdx.x * ILP;

using VecT = aligned_vector<T, ILP>;
T sum_v[ILP];

cub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
cub::KeyValuePair<float, float> thread_data(0.f, 0.f);

if (ILP * threadIdx.x < ld) { // load data under this guard to avoid reading out-of-bounds
T skip_v[ILP], bias_v[ILP];
Expand All @@ -155,8 +139,8 @@ __global__ void SkipLayerNormKernelSmall(
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}

T rldval_sum = T(0.f);
T rldvalsq_sum = T(0.f);
float rldval_sum = 0.f;
float rldvalsq_sum = 0.f;
const bool has_sum_output = (sum_output != nullptr);

#pragma unroll
Expand All @@ -166,16 +150,17 @@ __global__ void SkipLayerNormKernelSmall(
}
sum_v[i] += skip_v[i];

const T rldval = rld * sum_v[i];
const float val_f = static_cast<float>(sum_v[i]);
const float rldval = rld * val_f;
rldval_sum += rldval;
rldvalsq_sum += rldval * sum_v[i];
rldvalsq_sum += rldval * val_f;
}

if (has_sum_output) {
*(reinterpret_cast<VecT*>(&sum_output[idx])) = *reinterpret_cast<VecT*>(&sum_v);
}

thread_data = cub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
thread_data = cub::KeyValuePair<float, float>(rldval_sum, rldvalsq_sum);
}

if (Simplified) {
Expand Down Expand Up @@ -203,11 +188,11 @@ void LaunchSkipLayerNormKernel(

#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified><<<grid_size, block_size, 0, stream>>>( \
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)
output, sum_output, input, skip, bias, gamma, beta, epsilon, ld, skip_size)

#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \
SkipLayerNormKernel<T, block_size, Simplified><<<grid_size, block_size, 0, stream>>>( \
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)
output, sum_output, input, skip, bias, gamma, beta, epsilon, ld, skip_size)

#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
Expand Down
Loading
Loading