Skip to content

Commit 09faeb0

Browse files
Shuyang Liumeta-codesync[bot]
authored andcommitted
Add SVE-FP16 version of EmbeddingSpMDMNbit (#5728)
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2658 Pull Request resolved: #5728 ## Overview Add an SVE/NEON FP16 kernel for N-bit (2-bit/4-bit) quantized sparse embedding lookup on ARM, gated behind the `FBGEMM_TBE_SVE_FP16_ACCUMULATOR=1` runtime flag. The kernel is dispatched only when the output type is fp16 (OutType == uint16_t, !is_bf16_out), where fp16 accumulation precision is acceptable. For `dims ≤ 64 (iters <= 8)`, accumulates entirely in NEON fp16 registers with 2-way unrolling (two embeddings per iteration), then stores directly via `vst1q_f16` — no intermediate buffer, no fp32 widening. `dims < 8 (iters == 0)` fall through to the buf-based path which handles sub-8-element tails correctly. NBit values are unpacked via `unpack_4bit_to_fp16_neon` (SVE zip-based) and `unpack_2bit_to_fp16_neon` (NEON table-based) before the FMA. ## Result ## N-Bit Kernel with FP16 Output: Autovec vs SVE FP16 Register (direct store) **Output type:** fp16 (`-DOUT_TYPE_FLOAT16`) The autovec baseline accumulates in fp16 via a scratch buffer with fp16→fp32→fp16 round-trip for output. The SVE FP16 register path accumulates in NEON fp16 registers and stores directly with `vst1q_f16`. ### 4-Bit Results #### SLS | Dim | Cache | Prefetch | Autovec | FP16 Reg | Speedup | |-----|-------|----------|---------|----------|---------| | 32 | Hot | Off | 2.22 | 5.40 | **2.44x** | | 32 | Hot | On | 2.18 | 5.31 | **2.44x** | | 32 | Cold | Off | 0.67 | 1.01 | **1.51x** | | 32 | Cold | On | 0.72 | 0.98 | **1.37x** | | 48 | Hot | Off | 2.83 | 4.99 | **1.76x** | | 48 | Hot | On | 2.83 | 4.86 | **1.72x** | | 48 | Cold | Off | 1.04 | 1.20 | **1.15x** | | 48 | Cold | On | 1.05 | 1.19 | **1.14x** | | 56 | Hot | Off | 4.06 | 5.14 | **1.27x** | | 56 | Hot | On | 4.06 | 5.16 | **1.27x** | | 56 | Cold | Off | 2.29 | 2.09 | 0.91x | | 56 | Cold | On | 0.85 | 2.08 | **2.45x** | | 64 | Hot | Off | 3.05 | 5.16 | **1.69x** | | 64 | Hot | On | 3.04 | 5.13 | **1.68x** | | 64 | Cold | Off | 1.37 | 1.48 | 1.08x | | 64 | Cold | On | 1.23 | 1.44 | **1.16x** | | 80 | Hot | Off | 2.93 | 3.78 | **1.29x** | | 80 | Hot | On | 2.99 | 3.85 | **1.29x** | | 80 | Cold | Off | 1.34 | 1.63 | **1.22x** | | 80 | Cold | On | 1.16 | 1.61 | **1.38x** | | 96 | Hot | Off | 3.03 | 3.81 | **1.26x** | | 96 | Hot | On | 3.04 | 3.81 | **1.25x** | | 96 | Cold | Off | 1.58 | 2.55 | **1.61x** | | 96 | Cold | On | 1.45 | 1.88 | **1.29x** | | 104 | Hot | Off | 2.79 | 3.67 | **1.31x** | | 104 | Hot | On | 2.79 | 3.64 | **1.31x** | | 104 | Cold | Off | 1.47 | 1.98 | **1.35x** | | 104 | Cold | On | 1.54 | 2.01 | **1.31x** | | 128 | Hot | Off | 3.80 | 4.19 | **1.10x** | | 128 | Hot | On | 3.80 | 4.19 | **1.10x** | | 128 | Cold | Off | 1.82 | 1.64 | 0.90x | | 128 | Cold | On | 1.79 | 1.64 | 0.92x | | 192 | Hot | Off | 3.02 | 3.54 | **1.17x** | | 192 | Hot | On | 3.02 | 3.54 | **1.17x** | | 192 | Cold | Off | 2.59 | 2.79 | 1.07x | | 192 | Cold | On | 2.57 | 2.89 | **1.12x** | | 256 | Hot | Off | 3.05 | 3.57 | **1.17x** | | 256 | Hot | On | 3.06 | 3.58 | **1.17x** | | 256 | Cold | Off | 2.83 | 2.97 | 1.05x | | 256 | Cold | On | 2.79 | 2.93 | 1.05x | | 512 | Hot | Off | 3.04 | 5.53 | **1.82x** | | 512 | Hot | On | 3.02 | 5.52 | **1.83x** | | 512 | Cold | Off | 2.89 | 3.31 | **1.15x** | | 512 | Cold | On | 2.89 | 3.31 | **1.14x** | | 1024 | Hot | Off | 3.01 | 5.18 | **1.72x** | | 1024 | Hot | On | 3.01 | 5.43 | **1.81x** | | 1024 | Cold | Off | 2.84 | 3.69 | **1.30x** | | 1024 | Cold | On | 2.74 | 3.65 | **1.33x** | ### 4-Bit Results #### SLW | Dim | Cache | Prefetch | Autovec | FP16 Reg | Speedup | |-----|-------|----------|---------|----------|---------| | 32 | Hot | Off | 2.03 | 4.41 | **2.17x** | | 32 | Hot | On | 2.04 | 4.52 | **2.22x** | | 32 | Cold | Off | 0.70 | 0.98 | **1.39x** | | 32 | Cold | On | 0.67 | 1.01 | **1.50x** | | 48 | Hot | Off | 2.75 | 4.52 | **1.64x** | | 48 | Hot | On | 2.75 | 4.51 | **1.64x** | | 48 | Cold | Off | 0.97 | 1.18 | **1.22x** | | 48 | Cold | On | 0.97 | 1.30 | **1.34x** | | 56 | Hot | Off | 3.71 | 4.77 | **1.29x** | | 56 | Hot | On | 3.73 | 4.80 | **1.29x** | | 56 | Cold | Off | 2.11 | 2.02 | 0.95x | | 56 | Cold | On | 2.08 | 2.02 | 0.97x | | 64 | Hot | Off | 2.90 | 4.65 | **1.60x** | | 64 | Hot | On | 2.90 | 4.65 | **1.60x** | | 64 | Cold | Off | 1.41 | 1.60 | **1.13x** | | 64 | Cold | On | 1.30 | 1.58 | **1.21x** | | 80 | Hot | Off | 2.87 | 3.69 | **1.29x** | | 80 | Hot | On | 2.87 | 3.70 | **1.29x** | | 80 | Cold | Off | 1.37 | 1.68 | **1.23x** | | 80 | Cold | On | 1.48 | 1.65 | **1.12x** | | 96 | Hot | Off | 2.93 | 3.79 | **1.29x** | | 96 | Hot | On | 2.94 | 3.80 | **1.29x** | | 96 | Cold | Off | 1.74 | 2.39 | **1.37x** | | 96 | Cold | On | 1.41 | 2.32 | **1.65x** | | 104 | Hot | Off | 2.76 | 3.63 | **1.31x** | | 104 | Hot | On | 2.77 | 3.61 | **1.31x** | | 104| Cold | Off | 1.33 | 1.94 | **1.46x** | | 104 | Cold | On | 1.44 | 2.07 | **1.43x** | | 128 | Hot | Off | 3.85 | 4.08 | 1.06x | | 128 | Hot | On | 3.85 | 4.07 | 1.06x | | 128 | Cold | Off | 1.68 | 1.64 | 0.98x | | 128 | Cold | On | 1.67 | 1.64 | 0.98x | | 192 | Hot | Off | 2.95 | 3.51 | **1.19x** | | 192 | Hot | On | 2.95 | 3.51 | **1.19x** | | 192 | Cold | Off | 2.49 | 2.87 | **1.15x** | | 192 | Cold | On | 2.47 | 2.68 | 1.08x | | 256 | Hot | Off | 3.05 | 3.44 | **1.13x** | | 256 | Hot | On | 3.03 | 3.42 | **1.13x** | | 256 | Cold | Off | 2.77 | 3.00 | 1.08x | | 256 | Cold | On | 2.75 | 2.98 | 1.09x | | 512 | Hot | Off | 3.03 | 5.36 | **1.77x** | | 512 | Hot | On | 3.03 | 5.33 | **1.76x** | | 512 | Cold | Off | 2.85 | 3.43 | **1.21x** | | 512 | Cold | On | 2.86 | 3.45 | **1.21x** | | 1024 | Hot | Off | 2.99 | 5.13 | **1.71x** | | 1024 | Hot | On | 3.00 | 5.34 | **1.78x** | | 1024 | Cold | Off | 2.80 | 3.70 | **1.32x** | | 1024 | Cold | On | 2.79 | 3.66 | **1.31x** | ### 2-Bit Results #### SLS | Dim | Cache | Prefetch | Autovec | FP16 Reg | Speedup | |-----|-------|----------|---------|----------|---------| | 32 | Hot | Off | 0.97 | 2.40 | **2.49x** | | 32 | Hot | On | 0.98 | 2.39 | **2.45x** | | 32 | Cold | Off | 0.43 | 0.59 | **1.37x** | | 32 | Cold | On | 0.46 | 0.62 | **1.33x** | | 48 | Hot | Off | 0.93 | 2.46 | **2.66x** | | 48 | Hot | On | 0.89 | 2.46 | **2.77x** | | 48 | Cold | Off | 0.87 | 1.00 | **1.16x** | | 48 | Cold | On | 0.88 | 1.01 | **1.14x** | | 56 | Hot | Off | 0.93 | 2.40 | **2.60x** | | 56 | Hot | On | 0.93 | 2.39 | **2.58x** | | 56 | Cold | Off | 0.48 | 0.72 | **1.52x** | | 56 | Cold | On | 0.44 | 0.73 | **1.64x** | | 64 | Hot | Off | 0.90 | 2.41 | **2.69x** | | 64 | Hot | On | 0.90 | 2.43 | **2.70x** | | 64 | Cold | Off | 0.44 | 0.75 | **1.68x** | | 64 | Cold | On | 0.51 | 0.83 | **1.63x** | | 80 | Hot | Off | 0.88 | 1.65 | **1.87x** | | 80 | Hot | On | 0.88 | 1.64 | **1.87x** | | 80 | Cold | Off | 0.44 | 0.80 | **1.85x** | | 80 | Cold | On | 0.54 | 0.74 | **1.39x** | | 96 | Hot | Off | 0.87 | 1.60 | **1.84x** | | 96 | Hot | On | 0.87 | 1.59 | **1.83x** | | 96 | Cold | Off | 0.56 | 0.75 | **1.34x** | | 96 | Cold | On | 0.63 | 0.85 | **1.36x** | | 104 | Hot | Off | 0.83 | 1.66 | **2.00x** | | 104 | Hot | On | 0.85 | 1.66 | **1.95x** | | 104 | Cold | Off | 0.57 | 0.89 | **1.57x** | | 104 | Cold | On | 0.68 | 0.88 | **1.29x** | | 128 | Hot | Off | 0.85 | 1.40 | **1.65x** | | 128 | Hot | On | 0.86 | 1.41 | **1.64x** | | 128 | Cold | Off | 0.69 | 0.96 | **1.40x** | | 128 | Cold | On | 0.66 | 0.96 | **1.44x** | | 192 | Hot | Off | 0.80 | 1.64 | **2.05x** | | 192 | Hot | On | 0.82 | 1.65 | **2.02x** | | 192 | Cold | Off | 0.76 | 1.38 | **1.82x** | | 192 | Cold | On | 0.71 | 1.20 | **1.70x** | | 256 | Hot | Off | 0.83 | 1.60 | **1.93x** | | 256 | Hot | On | 0.83 | 1.61 | **1.93x** | | 256 | Cold | Off | 0.81 | 1.50 | **1.85x** | | 256 | Cold | On | 0.82 | 1.50 | **1.84x** | | 512 | Hot | Off | 0.80 | 1.96 | **2.45x** | | 512 | Hot | On | 0.80 | 1.95 | **2.45x** | | 512 | Cold | Off | 0.81 | 1.85 | **2.30x** | | 512 | Cold | On | 0.81 | 1.89 | **2.33x** | | 1024 | Hot | Off | 0.78 | 1.94 | **2.48x** | | 1024 | Hot | On | 0.79 | 1.93 | **2.43x** | | 1024 | Cold | Off | 0.80 | 1.87 | **2.33x** | | 1024 | Cold | On | 0.80 | 1.87 | **2.33x** | ### 2-Bit Results #### SLW | Dim | Cache | Prefetch | Autovec | FP16 Reg | Speedup | |-----|-------|----------|---------|----------|---------| | 32 | Hot | Off | 0.97 | 2.15 | **2.22x** | | 32 | Hot | On | 0.96 | 2.14 | **2.22x** | | 32 | Cold | Off | 0.42 | 0.55 | **1.30x** | | 32 | Cold | On | 0.43 | 0.58 | **1.34x** | | 48 | Hot | Off | 0.92 | 2.21 | **2.40x** | | 48 | Hot | On | 0.93 | 2.20 | **2.37x** | | 48 | Cold | Off | 0.88 | 1.02 | **1.16x** | | 48 | Cold | On | 0.88 | 1.00 | **1.14x** | | 56 | Hot | Off | 0.91 | 2.22 | **2.43x** | | 56 | Hot | On | 0.91 | 2.22 | **2.43x** | | 56 | Cold | Off | 0.46 | 0.73 | **1.57x** | | 56 | Cold | On | 0.43 | 0.66 | **1.51x** | | 64 | Hot | Off | 0.90 | 2.27 | **2.54x** | | 64 | Hot | On | 0.88 | 2.26 | **2.59x** | | 64 | Cold | Off | 0.48 | 0.71 | **1.47x** | | 64 | Cold | On | 0.14 | 0.71 | **5.25x** | | 80 | Hot | Off | 0.88 | 1.61 | **1.83x** | | 80 | Hot | On | 0.89 | 1.61 | **1.82x** | | 80 | Cold | Off | 0.49 | 0.72 | **1.48x** | | 80 | Cold | On | 0.52 | 0.73 | **1.41x** | | 96 | Hot | Off | 0.88 | 1.56 | **1.78x** | | 96 | Hot | On | 0.87 | 1.54 | **1.78x** | | 96 | Cold | Off | 0.58 | 0.81 | **1.39x** | | 96 | Cold | On | 0.60 | 0.85 | **1.41x** | | 104 | Hot | Off | 0.85 | 1.60 | **1.88x** | | 104 | Hot | On | 0.87 | 1.61 | **1.85x** | | 104 | Cold | Off | 0.61 | 0.84 | **1.38x** | | 104 | Cold | On | 0.65 | 0.91 | **1.40x** | | 128 | Hot | Off | 0.84 | 1.40 | **1.66x** | | 128 | Hot | On | 0.86 | 1.39 | **1.62x** | | 128 | Cold | Off | 0.62 | 0.91 | **1.48x** | | 128 | Cold | On | 0.67 | 1.03 | **1.54x** | | 192 | Hot | Off | 0.82 | 1.60 | **1.96x** | | 192 | Hot | On | 0.84 | 1.60 | **1.91x** | | 192 | Cold | Off | 0.74 | 1.24 | **1.66x** | | 192 | Cold | On | 0.69 | 1.26 | **1.84x** | | 256 | Hot | Off | 0.80 | 1.58 | **1.99x** | | 256 | Hot | On | 0.78 | 1.56 | **2.01x** | | 256 | Cold | Off | 0.81 | 1.47 | **1.82x** | | 256 | Cold | On | 0.81 | 1.48 | **1.83x** | | 512 | Hot | Off | 0.81 | 1.93 | **2.38x** | | 512 | Hot | On | 0.80 | 1.93 | **2.40x** | | 512 | Cold | Off | 0.81 | 1.85 | **2.29x** | | 512 | Cold | On | 0.81 | 1.84 | **2.28x** | | 1024 | Hot | Off | 0.78 | 1.92 | **2.46x** | | 1024 | Hot | On | 0.78 | 1.92 | **2.46x** | | 1024 | Cold | Off | 0.80 | 1.86 | **2.32x** | | 1024 | Cold | On | 0.79 | 1.86 | **2.34x** | ### N-Bit FP16 Output Average Speedup (SVE FP16 vs Autovec) | Bit Rate | Cache | SLS | SLW | |----------|-------|-----|-----| | 4 | Hot | **1.50x** | **1.46x** | | 4 | Cold | **1.25x** | **1.22x** | | 2 | Hot | **2.22x** | **2.12x** | | 2 | Cold | **1.59x** | **1.78x** | Reviewed By: helloguo Differential Revision: D99165466 fbshipit-source-id: f563c32fd1545f0fe989cff6f44810aefddb9900
1 parent 931739f commit 09faeb0

5 files changed

Lines changed: 2295 additions & 17 deletions

File tree

Lines changed: 387 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,387 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Benchmark for NBit embedding with fp16 output.
10+
// Compares SVE FP16 register path (via GenerateEmbeddingSpMDMNBitWithStrides)
11+
// against autovec baseline.
12+
13+
#include <algorithm>
14+
#include <cassert>
15+
#include <cstdint>
16+
#include <iostream>
17+
#include <numeric>
18+
#include <random>
19+
#include <vector>
20+
21+
#include "./BenchUtils.h"
22+
#include "fbgemm/Fbgemm.h"
23+
#include "fbgemm/FbgemmConvert.h"
24+
#include "src/EmbeddingSpMDMAutovec.h"
25+
26+
using namespace std;
27+
using namespace fbgemm;
28+
29+
static vector<vector<int>> GetInputs_() {
30+
vector<vector<int>> input_dims = {
31+
// batch size, number of rows of table, emb dim, avg length
32+
{10, 4000000, 4, 100},
33+
{10, 4000000, 7, 100},
34+
{10, 4000000, 32, 100},
35+
{10, 4000000, 33, 100},
36+
{10, 4000000, 48, 100},
37+
{10, 4000000, 56, 100},
38+
{10, 4000000, 64, 100},
39+
{10, 4000000, 65, 100},
40+
{10, 4000000, 80, 100},
41+
{10, 4000000, 96, 100},
42+
{10, 4000000, 104, 100},
43+
{10, 4000000, 128, 100},
44+
{10, 4000000, 143, 100},
45+
{10, 4000000, 192, 100},
46+
{10, 4000000, 256, 100},
47+
{10, 4000000, 263, 100},
48+
{10, 4000000, 512, 100},
49+
{10, 4000000, 1024, 100}};
50+
return input_dims;
51+
}
52+
53+
static int run_benchmark(
54+
int bit_rate,
55+
int batch_size,
56+
int num_rows,
57+
int embedding_dim,
58+
int average_len,
59+
bool normalize_by_lengths,
60+
bool use_32_bit_indices = false,
61+
bool prefetch = false) {
62+
int num_elem_per_byte = 8 / bit_rate;
63+
int fused_embedding_dim =
64+
(embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte +
65+
2 * sizeof(float16);
66+
default_random_engine generator;
67+
normal_distribution<float> embedding_distribution;
68+
69+
vector<uint8_t> fused_embedding_table(
70+
static_cast<size_t>(num_rows) * fused_embedding_dim);
71+
for (int i = 0; i < num_rows; i++) {
72+
for (int ii = 0;
73+
ii < (embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte;
74+
ii++) {
75+
fused_embedding_table[static_cast<size_t>(i) * fused_embedding_dim + ii] =
76+
2;
77+
}
78+
float16* scale_bias = reinterpret_cast<float16*>(
79+
&fused_embedding_table[static_cast<size_t>(i) * fused_embedding_dim] +
80+
(embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte);
81+
float scale = 2.0f;
82+
float bias = 1.0f;
83+
FloatToFloat16_ref(&scale, scale_bias, 1, true);
84+
FloatToFloat16_ref(&bias, scale_bias + 1, 1, true);
85+
}
86+
87+
uniform_int_distribution<int> length_distribution(
88+
1, std::min(2 * average_len + 1, num_rows));
89+
vector<int> offsets(batch_size + 1);
90+
offsets[0] = 0;
91+
for (int i = 0; i < batch_size; ++i) {
92+
offsets[i + 1] = offsets[i] + length_distribution(generator);
93+
}
94+
int lengths_sum = offsets[batch_size];
95+
cout << "lengths_sum " << lengths_sum << '\n';
96+
97+
vector<int64_t> indices;
98+
vector<int32_t> indices_32;
99+
vector<int> container(num_rows);
100+
for (int i = 0; i < batch_size; ++i) {
101+
iota(container.begin(), container.end(), 0);
102+
shuffle(container.begin(), container.end(), generator);
103+
copy(
104+
container.begin(),
105+
container.begin() + (offsets[i + 1] - offsets[i]),
106+
back_inserter(indices));
107+
}
108+
copy(begin(indices), end(indices), back_inserter(indices_32));
109+
110+
vector<float> weights(lengths_sum);
111+
for (int i = 0; i < lengths_sum; ++i) {
112+
weights[i] = embedding_distribution(generator);
113+
}
114+
115+
using OutType = float16;
116+
vector<OutType> output_ref(static_cast<size_t>(batch_size) * embedding_dim);
117+
vector<OutType> output(output_ref.size());
118+
vector<OutType> output_autovec(output_ref.size());
119+
120+
constexpr int NUM_WARMUP = 10;
121+
constexpr int NUM_ITER = 100;
122+
double bytes = static_cast<double>(lengths_sum) *
123+
(fused_embedding_dim + (use_32_bit_indices ? 4 : 8)) +
124+
batch_size * sizeof(int);
125+
constexpr int CACHE_LINE_LEN = 64;
126+
double bytes_padded = static_cast<double>(lengths_sum) *
127+
(CACHE_LINE_LEN *
128+
static_cast<int64_t>(
129+
(fused_embedding_dim + CACHE_LINE_LEN - 1) /
130+
CACHE_LINE_LEN) +
131+
(use_32_bit_indices ? 4 : 8)) +
132+
batch_size * sizeof(int);
133+
134+
for (bool has_weight : {false, true}) {
135+
// Main kernel via GenerateEmbeddingSpMDMNBitWithStrides (includes SVE FP16
136+
// dispatch when FBGEMM_SVE_FP16=1)
137+
auto kernel_32 = GenerateEmbeddingSpMDMNBitWithStrides<
138+
/*IndexType=*/int32_t,
139+
/*OffsetType=*/int32_t,
140+
/*OutType=*/OutType>(
141+
bit_rate,
142+
embedding_dim,
143+
has_weight,
144+
normalize_by_lengths,
145+
prefetch ? 16 : 0,
146+
/*is_weight_positional=*/false,
147+
/*use_offsets=*/true,
148+
/*output_stride=*/-1,
149+
/*input_stride=*/-1,
150+
/*scale_bias_last=*/true,
151+
/*is_bf16_out=*/false);
152+
auto kernel_64 = GenerateEmbeddingSpMDMNBitWithStrides<
153+
/*IndexType=*/int64_t,
154+
/*OffsetType=*/int32_t,
155+
/*OutType=*/OutType>(
156+
bit_rate,
157+
embedding_dim,
158+
has_weight,
159+
normalize_by_lengths,
160+
prefetch ? 16 : 0,
161+
/*is_weight_positional=*/false,
162+
/*use_offsets=*/true,
163+
/*output_stride=*/-1,
164+
/*input_stride=*/-1,
165+
/*scale_bias_last=*/true,
166+
/*is_bf16_out=*/false);
167+
168+
#ifdef FBGEMM_AUTOVEC_AVAILABLE
169+
auto kernel_32_autovec = GenerateEmbeddingSpMDMNBitWithStrides_autovec<
170+
/*IndexType=*/int32_t,
171+
/*OffsetType=*/int32_t,
172+
/*OutType=*/OutType>(
173+
bit_rate,
174+
embedding_dim,
175+
has_weight,
176+
normalize_by_lengths,
177+
prefetch ? 16 : 0,
178+
/*is_weight_positional=*/false,
179+
/*use_offsets=*/true,
180+
/*output_stride=*/-1,
181+
/*input_stride=*/-1,
182+
/*scale_bias_last=*/true,
183+
/*is_bf16_out=*/false,
184+
/*no_bag=*/false,
185+
/*output_bit_rate=*/-1);
186+
auto kernel_64_autovec = GenerateEmbeddingSpMDMNBitWithStrides_autovec<
187+
/*IndexType=*/int64_t,
188+
/*OffsetType=*/int32_t,
189+
/*OutType=*/OutType>(
190+
bit_rate,
191+
embedding_dim,
192+
has_weight,
193+
normalize_by_lengths,
194+
prefetch ? 16 : 0,
195+
/*is_weight_positional=*/false,
196+
/*use_offsets=*/true,
197+
/*output_stride=*/-1,
198+
/*input_stride=*/-1,
199+
/*scale_bias_last=*/true,
200+
/*is_bf16_out=*/false,
201+
/*no_bag=*/false,
202+
/*output_bit_rate=*/-1);
203+
#endif
204+
205+
for (bool flush_cache : {false, true}) {
206+
// Main kernel
207+
double t = measureWithWarmup(
208+
[&]() {
209+
if (use_32_bit_indices) {
210+
kernel_32(
211+
batch_size,
212+
lengths_sum,
213+
num_rows,
214+
fused_embedding_table.data(),
215+
indices_32.data(),
216+
offsets.data(),
217+
has_weight ? weights.data() : nullptr,
218+
output.data());
219+
} else {
220+
kernel_64(
221+
batch_size,
222+
lengths_sum,
223+
num_rows,
224+
fused_embedding_table.data(),
225+
indices.data(),
226+
offsets.data(),
227+
has_weight ? weights.data() : nullptr,
228+
output.data());
229+
}
230+
},
231+
NUM_WARMUP,
232+
NUM_ITER,
233+
[&]() {
234+
if (flush_cache) {
235+
cache_evict(fused_embedding_table);
236+
cache_evict(indices);
237+
cache_evict(indices_32);
238+
cache_evict(offsets);
239+
cache_evict(weights);
240+
cache_evict(output);
241+
}
242+
});
243+
244+
#ifdef FBGEMM_AUTOVEC_AVAILABLE
245+
// Autovec kernel
246+
double t_autovec = measureWithWarmup(
247+
[&]() {
248+
if (use_32_bit_indices) {
249+
kernel_32_autovec(
250+
batch_size,
251+
lengths_sum,
252+
num_rows,
253+
fused_embedding_table.data(),
254+
indices_32.data(),
255+
offsets.data(),
256+
has_weight ? weights.data() : nullptr,
257+
output_autovec.data());
258+
} else {
259+
kernel_64_autovec(
260+
batch_size,
261+
lengths_sum,
262+
num_rows,
263+
fused_embedding_table.data(),
264+
indices.data(),
265+
offsets.data(),
266+
has_weight ? weights.data() : nullptr,
267+
output_autovec.data());
268+
}
269+
},
270+
NUM_WARMUP,
271+
NUM_ITER,
272+
[&]() {
273+
if (flush_cache) {
274+
cache_evict(fused_embedding_table);
275+
cache_evict(indices);
276+
cache_evict(indices_32);
277+
cache_evict(offsets);
278+
cache_evict(weights);
279+
cache_evict(output_autovec);
280+
}
281+
});
282+
#endif
283+
284+
// Output
285+
cout << "out type fp16, ";
286+
if (has_weight) {
287+
cout << "SLW(WEIGHTED), ";
288+
} else {
289+
cout << "SLS, ";
290+
}
291+
if (normalize_by_lengths) {
292+
cout << "normalize, ";
293+
}
294+
if (flush_cache) {
295+
cout << "cache flushed, ";
296+
} else {
297+
cout << "cache not flushed, ";
298+
}
299+
if (prefetch) {
300+
cout << "prefetch on, ";
301+
} else {
302+
cout << "prefetch off, ";
303+
}
304+
305+
cout << "b/w, " << bytes / 1e9 / t << ", GB/s, "
306+
<< "effective b/w, " << bytes_padded / 1e9 / t << ", GB/s, "
307+
<< "time, " << t;
308+
#ifdef FBGEMM_AUTOVEC_AVAILABLE
309+
cout << ", autovec b/w, " << bytes / 1e9 / t_autovec << ", GB/s, "
310+
<< "autovec eff. b/w, " << bytes_padded / 1e9 / t_autovec
311+
<< ", GB/s, "
312+
<< "autovec time, " << t_autovec << ", speedup vs autovec, "
313+
<< t_autovec / t;
314+
#endif
315+
cout << '\n';
316+
cout.flush();
317+
} // flush_cache
318+
} // has_weight
319+
return 0;
320+
}
321+
322+
int main() {
323+
vector<vector<int>> inputs(GetInputs_());
324+
325+
for (int bit_rate : {4, 2}) {
326+
for (auto& input : inputs) {
327+
assert(input.size() > 3);
328+
int batch_size = input[0];
329+
int num_rows = input[1];
330+
int embedding_dim = input[2];
331+
int average_len = input[3];
332+
333+
cout << "bit_rate, " << bit_rate << ", batch size, " << batch_size
334+
<< ", num rows, " << num_rows << ", emb dim, " << embedding_dim
335+
<< ", avg length, " << average_len << '\n';
336+
337+
for (bool normalize : {false, true}) {
338+
// 64-bit indices, no prefetch
339+
cout << "64 bit indices, ";
340+
run_benchmark(
341+
bit_rate,
342+
batch_size,
343+
num_rows,
344+
embedding_dim,
345+
average_len,
346+
normalize);
347+
348+
// 64-bit indices, with prefetch
349+
cout << "64 bit indices with prefetching, ";
350+
run_benchmark(
351+
bit_rate,
352+
batch_size,
353+
num_rows,
354+
embedding_dim,
355+
average_len,
356+
normalize,
357+
/*use_32_bit_indices=*/false,
358+
/*prefetch=*/true);
359+
360+
// 32-bit indices, no prefetch
361+
cout << "32 bit indices, ";
362+
run_benchmark(
363+
bit_rate,
364+
batch_size,
365+
num_rows,
366+
embedding_dim,
367+
average_len,
368+
normalize,
369+
/*use_32_bit_indices=*/true,
370+
/*prefetch=*/false);
371+
372+
// 32-bit indices, with prefetch
373+
cout << "32 bit indices with prefetching, ";
374+
run_benchmark(
375+
bit_rate,
376+
batch_size,
377+
num_rows,
378+
embedding_dim,
379+
average_len,
380+
normalize,
381+
/*use_32_bit_indices=*/true,
382+
/*prefetch=*/true);
383+
}
384+
}
385+
}
386+
return 0;
387+
}

0 commit comments

Comments
 (0)