Commit 4aaaa93

Shuyang Liu authored and facebook-github-bot committed
Add multithreading to table lookup (#5849)
Summary:
X-link: facebookresearch/FBGEMM#2767

## What

Parallelizes the per-table loop in the CPU TBE forward kernel (IntNBitTableBatchedEmbeddingBagsCodegen::forward) across tables. The per-table loop is embarrassingly parallel: each table reads its own weight slice and writes a disjoint slice of the output, so fanning tables out across threads speeds up table-heavy inference models (roughly 1.4x to 2.9x kernel speedup in the microbenchmark below).

Gated by the TBE_TABLE_THREADS env var:
- TBE_TABLE_THREADS=1 (default): unchanged sequential behavior.
- TBE_TABLE_THREADS=N>1: tables are distributed over N OpenMP threads with dynamic scheduling (good load balancing when table sizes are skewed).

## Default behavior is unchanged

When TBE_TABLE_THREADS<=1 the helper takes an early return and runs the loop body sequentially with no try/catch wrapper and no thread-local-state guard, so the default path is functionally identical to the pre-change code: same iteration order, the DEVICE-placement TORCH_CHECK in its original per-table position, same error semantics, and the same generated machine code for the body. The only always-on changes are mechanical and behavior-preserving (loop-local `weights` pointer, int64 loop index).

## Design

- Raw OpenMP (#pragma omp parallel + omp for, schedule(dynamic)) rather than at::parallel_for, so TBE gets its own thread count independent of the global intra-op pool / OMP_NUM_THREADS (predictors run with OMP_NUM_THREADS=1).
- Thread count is read once from the env var and cached (thread-safe static init), and clamped to the number of tables.

## Correctness

- Removed the function-scoped `weights_acc` pointer, which every iteration overwrote (a data race once the loop is parallel). Replaced with a loop-local pointer (identical pointer value). Every other variable in the loop body is already loop-local, and each table writes a disjoint output slice (output_acc + D_start), so results are bitwise-identical to the sequential path.
- The per-table DEVICE-placement TORCH_CHECK stays in its original position. In the threaded path it is handled like any other throw from the loop body (kernel errors, at::arange checks): the first exception is captured and rethrown after the join, so no exception escapes the OpenMP region.
- Worker threads restore the caller's at::ThreadLocalState (dispatch keys, grad/inference mode, autocast, ...), so ATen calls inside the loop (e.g. at::arange in the nobag path) run with the correct thread-local context.

## Verification

- Builds clean (mode/opt). Confirmed OpenMP is actually enabled for this target by inspecting the compiled object: gen_*_codegen_cpu.cpp.pic.o references __kmpc_fork_call / omp_get_num_threads (the pragma is NOT a no-op, even though no -fopenmp appears in the TARGETS; the fbcode default toolchain supplies it).
- nbit_forward CPU unit tests pass with TBE_TABLE_THREADS=4, including test_nbit_forward_cpu_with_table_sharing (non-monotonic weights_offsets).

## Benchmark (INT4, B=512, D=128, E=100K, L=20, FP16/SUM, iters=100, 3-run avg)

| Tables | Threads | Avg (us) | BW (GB/s) | Speedup | Efficiency |
|--------|---------|----------|-----------|---------|------------|
| 8      | 1       | 1,436    | 4.38      | ---     | ---        |
| 8      | 2       | 1,042    | 6.03      | 1.38x   | 69%        |
| 8      | 4       | 844      | 7.46      | 1.70x   | 43%        |
| 8      | 8       | 773      | 8.14      | 1.86x   | 23%        |
| 32     | 1       | 4,797    | 5.25      | ---     | ---        |
| 32     | 2       | 2,830    | 8.90      | 1.69x   | 85%        |
| 32     | 4       | 2,003    | 12.60     | 2.40x   | 60%        |
| 32     | 8       | 1,633    | 15.49     | 2.94x   | 37%        |
| 64     | 1       | 10,132   | 4.97      | ---     | ---        |
| 64     | 2       | 6,767    | 7.44      | 1.50x   | 75%        |
| 64     | 4       | 4,864    | 10.35     | 2.08x   | 52%        |
| 64     | 8       | 4,033    | 12.50     | 2.51x   | 31%        |

2 threads is the efficiency sweet spot (69-85%); efficiency falls off at higher counts due to fixed fork/join overhead per call. This matches the production recommendation TBE_TABLE_THREADS=2 (+7.3% QPS measured in ICE). Microbenchmark kernel speedups overpredict end-to-end QPS (Amdahl: CPU TBE is a small fraction of total inference latency).

Reviewed By: helloguo, q10

Differential Revision: D102867249
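
The Speedup and Efficiency columns can be recomputed from the average times: speedup is the 1-thread time divided by the N-thread time, and efficiency is that speedup divided by N. A minimal sketch of the arithmetic follows; the helper names `speedup` and `efficiency` are hypothetical and not part of FBGEMM.

```cpp
#include <cassert>

// Hypothetical helpers for illustration only; not part of FBGEMM.

// Speedup of an N-thread run relative to the 1-thread baseline.
inline double speedup(double baseline_us, double threaded_us) {
  assert(baseline_us > 0 && threaded_us > 0);
  return baseline_us / threaded_us;
}

// Parallel efficiency: the fraction of ideal linear scaling achieved.
inline double efficiency(double baseline_us, double threaded_us, int threads) {
  assert(threads >= 1);
  return speedup(baseline_us, threaded_us) / threads;
}
```

For the 8-table, 2-thread row, `speedup(1436, 1042)` is about 1.38 and `efficiency(1436, 1042, 2)` is about 0.69. The table lists rounded averages, so recomputing from the rounded times can differ from a listed value in the last digit.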
1 parent fa211b0 commit 4aaaa93

1 file changed

Lines changed: 75 additions & 6 deletions

fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp

@@ -25,7 +25,9 @@
 #include <immintrin.h>
 #include <emmintrin.h>
 #endif
+#include <charconv>
 #include <cstring>
+#include <ATen/ThreadLocalState.h>
 
 using namespace fbgemm_gpu;
 
@@ -37,6 +39,74 @@ C10_NOINLINE void check_fp8_params(int64_t fp8_exponent_bits, int64_t fp8_expone
     TORCH_CHECK(fp8_exponent_bits > 0 && fp8_exponent_bias > 0, "FP8 requires fp8_exponent_bits > 0 (got ", fp8_exponent_bits, ") and fp8_exponent_bias > 0 (got ", fp8_exponent_bias, ")");
 }
 
+inline int get_tbe_table_threads() {
+    static const int n = []() {
+        const char* env = std::getenv("TBE_TABLE_THREADS");
+        if (!env || *env == '\0') {
+            return 1;
+        }
+        int val = 0;
+        auto [ptr, ec] = std::from_chars(env, env + std::strlen(env), val);
+        if (ec != std::errc{} || *ptr != '\0') {
+            return 1;
+        }
+        return std::max<int>(1, val);
+    }();
+    return n;
+}
+
+template <typename F>
+inline void parallel_for_table_threads(
+    int64_t begin,
+    int64_t end,
+    const F& f) {
+    const int num_threads = get_tbe_table_threads();
+    if (begin >= end || num_threads <= 1) {
+        f(begin, end);
+        return;
+    }
+    // Don't spawn more threads than there are tables.
+    // [[maybe_unused]]: only consumed by `#pragma omp parallel num_threads(...)`
+    // below. In OSS builds OpenMP is not enabled for this target, so the pragma is
+    // ignored and the variable is otherwise unused — without this attribute that
+    // trips -Werror=unused-variable (gcc) / -Werror,-Wunused-variable (clang) and
+    // fails the OSS FBGEMM_GPU CPU/CUDA build_artifact CI. (fbcode enables OpenMP,
+    // so the pragma consumes it there.)
+    [[maybe_unused]] const int effective_threads =
+        static_cast<int>(std::min<int64_t>(num_threads, end - begin));
+    // Raw OpenMP does not carry the caller's ThreadLocalState (dispatch keys,
+    // grad/inference mode, autocast, ...) to worker threads the way
+    // at::parallel_for does. Capture it here and restore it on each worker,
+    // otherwise ATen calls inside `f` (e.g. at::arange in the nobag path) run
+    // with the wrong thread-local context.
+    const at::ThreadLocalState tls;
+    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
+    std::exception_ptr eptr;
+    #pragma omp parallel num_threads(effective_threads)
+    {
+        try {
+            const at::ThreadLocalStateGuard tls_guard(tls);
+            #pragma omp for schedule(dynamic) nowait
+            for (int64_t t = begin; t < end; ++t) {
+                try {
+                    f(t, t + 1);
+                } catch (...) {
+                    if (!err_flag.test_and_set()) {
+                        eptr = std::current_exception();
+                    }
+                }
+            }
+        } catch (...) {
+            if (!err_flag.test_and_set()) {
+                eptr = std::current_exception();
+            }
+        }
+    }
+    if (eptr) {
+        std::rethrow_exception(eptr);
+    }
+}
+
 inline uint32_t pruned_hash_function(uint32_t h) {
     // MurmorHash3 32-bit mixing function.
     h ^= h >> 16;
@@ -240,8 +310,6 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
     }
 
     const int32_t* weights_placements_ptr = weights_placements.const_data_ptr<int32_t>();
-    const uint8_t* weights_acc;
-
     const auto* weights_tys_acc = weights_tys.const_data_ptr<uint8_t>();
 
     DISPATCH_OUTPUT_TYPES(output.scalar_type(), "intn_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", [&] {
@@ -280,7 +348,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
                 std::ranges::unique(physical_offsets).begin(),
                 physical_offsets.end());
 
-            for (const auto t : c10::irange(T)) {
+            parallel_for_table_threads(0, T, [&](int64_t begin, int64_t end) {
+                for (int64_t t = begin; t < end; ++t) {
                 {% if not nobag %}
                 const auto* D_offsets_acc = D_offsets.const_data_ptr<int32_t>();
                 const int32_t D_start = D_offsets_acc[t];
@@ -294,8 +363,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
                 const auto placement = static_cast<PlacementType>(weights_placements_ptr[t]);
                 TORCH_CHECK(placement != PlacementType::DEVICE);
                 const auto& weight_tensor = (placement == PlacementType::HOST) ? dev_weights : uvm_weights;
-                weights_acc = weight_tensor.const_data_ptr<uint8_t>();
-                const uint8_t* weights = &weights_acc[weights_offsets_acc[t]];
+                const uint8_t* weights = weight_tensor.const_data_ptr<uint8_t>() + weights_offsets_acc[t];
                 const auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
                 if (output_is_int8) {
                     TORCH_CHECK(weight_ty == SparseType::INT8, "int8 output are only supported for int8 weights");
@@ -451,7 +519,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
                         num_rows,
                         /*allow_minus_one=*/true);
                 }
-            }
+                }
+            });
             return;
         });
     });
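
The parsing rules in `get_tbe_table_threads` can be exercised in isolation. The sketch below restates them as a pure function so the fallback cases are easy to check. The name `parse_table_threads` is hypothetical and not in the commit, and the sketch leaves out the `std::getenv` call and the static caching.

```cpp
#include <algorithm>
#include <cassert>
#include <charconv>
#include <cstring>
#include <system_error>

// Hypothetical stand-alone restatement of the parsing rules; not the FBGEMM
// function. Returns 1 (sequential) for null, empty, non-numeric, out-of-range,
// or trailing-garbage input, and never returns less than 1.
inline int parse_table_threads(const char* env) {
  if (!env || *env == '\0') {
    return 1;
  }
  int val = 0;
  const char* last = env + std::strlen(env);
  auto [ptr, ec] = std::from_chars(env, last, val);
  // Reject parse errors and any unconsumed characters after the number.
  if (ec != std::errc{} || ptr != last) {
    return 1;
  }
  return std::max(1, val);
}
```

Under these rules `"4"` yields 4, while `""`, `"0"`, `"-3"`, `"4x"`, and `" 4"` all yield 1, because `std::from_chars` neither skips leading whitespace nor accepts a leading `+`.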

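
The exception handling in `parallel_for_table_threads` follows a "first exception wins, rethrow after the join" pattern. Below is a minimal sketch of that pattern without ATen, so it omits the `at::ThreadLocalState` capture and restore. The names `for_each_table` and `fill_doubled` are hypothetical, and the sketch uses a combined `parallel for` rather than the commit's separate `parallel` and `for` constructs. When OpenMP is not enabled the pragma is ignored and the loop simply runs sequentially.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <stdexcept>
#include <vector>

// Hypothetical helper, not the FBGEMM function: calls f(t) for each t in
// [begin, end), on up to n_threads OpenMP threads when n_threads > 1.
template <typename F>
void for_each_table(int64_t begin, int64_t end, int n_threads, const F& f) {
  if (begin >= end || n_threads <= 1) {
    // Sequential path: no wrapper, exceptions propagate directly.
    for (int64_t t = begin; t < end; ++t) {
      f(t);
    }
    return;
  }
  // Never request more threads than there are iterations.
  [[maybe_unused]] const int effective =
      static_cast<int>(std::min<int64_t>(n_threads, end - begin));
  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  std::exception_ptr eptr;
#pragma omp parallel for schedule(dynamic) num_threads(effective)
  for (int64_t t = begin; t < end; ++t) {
    try {
      f(t);
    } catch (...) {
      // Exceptions must not leave the OpenMP region; keep only the first.
      if (!err_flag.test_and_set()) {
        eptr = std::current_exception();
      }
    }
  }
  // The parallel region has joined; rethrow on the calling thread.
  if (eptr) {
    std::rethrow_exception(eptr);
  }
}

// Each iteration writes its own disjoint slot, mirroring how each table
// writes a disjoint slice of the output.
inline std::vector<int> fill_doubled(int n, int n_threads) {
  std::vector<int> out(static_cast<std::size_t>(n), 0);
  for_each_table(0, n, n_threads, [&](int64_t t) {
    out[static_cast<std::size_t>(t)] = 2 * static_cast<int>(t);
  });
  return out;
}
```

Because every iteration writes a distinct element, `fill_doubled(16, 4)` gives the same result as `fill_doubled(16, 1)`, and a `std::runtime_error` thrown from any iteration reaches the caller after all threads have finished.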