Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <ATen/Parallel.h>

#include "fbgemm_gpu/utils/cpu_utils.h"
#include "fbgemm_gpu/utils/embedding_cpu_threading.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm/FbgemmEmbedding.h"
Expand All @@ -26,6 +27,8 @@
#include <emmintrin.h>
#endif
#include <cstring>
#include <exception>
#include <ATen/ThreadLocalState.h>

using namespace fbgemm_gpu;

Expand All @@ -37,6 +40,65 @@ C10_NOINLINE void check_fp8_params(int64_t fp8_exponent_bits, int64_t fp8_expone
TORCH_CHECK(fp8_exponent_bits > 0 && fp8_exponent_bias > 0, "FP8 requires fp8_exponent_bits > 0 (got ", fp8_exponent_bits, ") and fp8_exponent_bias > 0 (got ", fp8_exponent_bias, ")");
}

template <typename F>
inline void parallel_for_table_threads(
int begin,
int end,
const F& f) {

const int num_threads = choose_num_threads(end - begin);

// Short-cut, no need to invoke openmp
if (num_threads == 1) {
f(begin, end);
return;
}

// Raw OpenMP does not carry the caller's ThreadLocalState (dispatch keys,
// grad/inference mode, autocast, ...) to worker threads the way
// at::parallel_for does. Capture it here and restore it on each worker,
// otherwise ATen calls inside `f` (e.g. at::arange in the nobag path) run
// with the wrong thread-local context.
const at::ThreadLocalState tls;

bool have_err = false;
std::exception_ptr eptr;

#pragma omp parallel num_threads(num_threads)
{
try {
const at::ThreadLocalStateGuard tls_guard(tls);
#pragma omp for schedule(dynamic) nowait
for (int t = begin; t < end; ++t) {
try {
f(t, t + 1);
} catch (...) {
// std::exception_ptr is not atomic,
// hence we need a critical section here in case multiple threads have exceptions
#pragma omp critical(tbe_table_threads_err)
{
if (!have_err) {
have_err = true;
eptr = std::current_exception();
}
}
}
}
} catch (...) {
#pragma omp critical(tbe_table_threads_err)
{
if (!have_err) {
have_err = true;
eptr = std::current_exception();
}
}
}
}
if (eptr) {
std::rethrow_exception(eptr);
}
}

inline uint32_t pruned_hash_function(uint32_t h) {
// MurmorHash3 32-bit mixing function.
h ^= h >> 16;
Expand Down Expand Up @@ -240,8 +302,6 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
}

const int32_t* weights_placements_ptr = weights_placements.const_data_ptr<int32_t>();
const uint8_t* weights_acc;

const auto* weights_tys_acc = weights_tys.const_data_ptr<uint8_t>();

DISPATCH_OUTPUT_TYPES(output.scalar_type(), "intn_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", [&] {
Expand Down Expand Up @@ -280,7 +340,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
std::ranges::unique(physical_offsets).begin(),
physical_offsets.end());

for (const auto t : c10::irange(T)) {
parallel_for_table_threads(0, T, [&](int begin, int end) {
for (int t = begin; t < end; ++t) {
{% if not nobag %}
const auto* D_offsets_acc = D_offsets.const_data_ptr<int32_t>();
const int32_t D_start = D_offsets_acc[t];
Expand All @@ -294,8 +355,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
const auto placement = static_cast<PlacementType>(weights_placements_ptr[t]);
TORCH_CHECK(placement != PlacementType::DEVICE);
const auto& weight_tensor = (placement == PlacementType::HOST) ? dev_weights : uvm_weights;
weights_acc = weight_tensor.const_data_ptr<uint8_t>();
const uint8_t* weights = &weights_acc[weights_offsets_acc[t]];
const uint8_t* weights = weight_tensor.const_data_ptr<uint8_t>() + weights_offsets_acc[t];
const auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
if (output_is_int8) {
TORCH_CHECK(weight_ty == SparseType::INT8, "int8 output are only supported for int8 weights");
Expand Down Expand Up @@ -451,7 +511,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
num_rows,
/*allow_minus_one=*/true);
}
}
}
});
return;
});
});
Expand Down
64 changes: 64 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/embedding_cpu_threading.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <algorithm>
#include <charconv>
#include <cstdlib>
#include <cstring>

namespace fbgemm_gpu {

// Default work-granularity (tables per thread)
constexpr int DEFAULT_TABLES_PER_THREAD = 16;

inline int
calculate_num_threads(int num_tables, int cap, int tables_per_thread) {
if (cap <= 1 || num_tables <= 1) {
return 1;
}
const int num_threads = num_tables / tables_per_thread;
return std::clamp<int>(num_threads, 1, cap);
}

inline int get_env_int(const char* name, int default_val) {
const char* env = std::getenv(name);
if (!env || *env == '\0') {
return default_val;
}
int val = 0;
auto [ptr, ec] = std::from_chars(env, env + std::strlen(env), val);
if (ec != std::errc{} || *ptr != '\0') {
return default_val;
}
return std::max<int>(1, val);
}

// Thread-count cap from env FBGEMM_TBE_MAX_NUM_THREADS
inline int get_tbe_max_num_threads() {
static const int n = get_env_int("FBGEMM_TBE_MAX_NUM_THREADS", 1);
return n;
}

// Work-granularity from env FBGEMM_TBE_MIN_TABLES_PER_THREAD
// We are using the number of tables as approximated
// minimal workload per thread (default 16) to avoid
// threading overhead
inline int get_tbe_min_tables_per_thread() {
static const int n = get_env_int(
"FBGEMM_TBE_MIN_TABLES_PER_THREAD", DEFAULT_TABLES_PER_THREAD);
return n;
}

inline int choose_num_threads(int num_tables) {
return calculate_num_threads(
num_tables, get_tbe_max_num_threads(), get_tbe_min_tables_per_thread());
}

} // namespace fbgemm_gpu
71 changes: 71 additions & 0 deletions fbgemm_gpu/test/tbe/inference/embedding_cpu_threading_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>

#include "fbgemm_gpu/utils/embedding_cpu_threading.h"

using fbgemm_gpu::calculate_num_threads;
using fbgemm_gpu::DEFAULT_TABLES_PER_THREAD;

namespace {
constexpr int G = DEFAULT_TABLES_PER_THREAD; // 16 by default
} // namespace

// The headline guarantee: with no env var set, FBGEMM_TBE_MAX_NUM_THREADS
// defaults to a cap of 1, and the per-call decision is ALWAYS 1 (serial) --
// regardless of work size or granularity. So the no-env-var path is identical
// to single-threaded TBE.
TEST(EmbeddingCpuThreadingTest, DefaultCapIsAlwaysSerial) {
// cap=1 == get_tbe_max_num_threads() when FBGEMM_TBE_MAX_NUM_THREADS is
// unset.
for (int work : {0, 1, 2, 13, 32, 358, 100000}) {
EXPECT_EQ(calculate_num_threads(work, 1, G), 1)
<< "cap=1 must stay serial for work=" << work;
// And it is independent of the granularity knob.
EXPECT_EQ(calculate_num_threads(work, 1, 1), 1);
EXPECT_EQ(calculate_num_threads(work, 1, 64), 1);
}
}

// Trivial-work calls never thread, even with a higher cap.
TEST(EmbeddingCpuThreadingTest, TrivialWorkIsSerial) {
for (int cap : {1, 2, 4, 8}) {
EXPECT_EQ(calculate_num_threads(0, cap, G), 1);
EXPECT_EQ(calculate_num_threads(1, cap, G), 1);
}
}

// Default granularity (G=16) puts the threading onset at 2*G = 32 tables,
// matching the validated A/B gate: small few-table lookups stay serial; large
// gathers thread.
TEST(EmbeddingCpuThreadingTest, DefaultGuardOnsetAt32) {
// 2T cap.
EXPECT_EQ(calculate_num_threads(7, 2, G), 1); // dpa remote_ro_event lookups
EXPECT_EQ(calculate_num_threads(13, 2, G), 1);
EXPECT_EQ(calculate_num_threads(31, 2, G), 1); // just below onset
EXPECT_EQ(calculate_num_threads(32, 2, G), 2); // onset
EXPECT_EQ(calculate_num_threads(358, 2, G), 2); // dpa remote_ro -> cap
}

// Grading scales one thread per G tables, clamped to the cap.
TEST(EmbeddingCpuThreadingTest, GradesUpToCap) {
// 4T cap, G=16.
EXPECT_EQ(calculate_num_threads(13, 4, G), 1); // event -> serial
EXPECT_EQ(calculate_num_threads(32, 4, G), 2); // 32/16 = 2
EXPECT_EQ(calculate_num_threads(48, 4, G), 3); // 48/16 = 3
EXPECT_EQ(calculate_num_threads(64, 4, G), 4); // 64/16 = 4 (cap)
EXPECT_EQ(calculate_num_threads(358, 4, G), 4); // clamped to cap
}

// G=1 reproduces the old unconditional behavior: thread every non-trivial call.
TEST(EmbeddingCpuThreadingTest, GranularityOneThreadsEverything) {
EXPECT_EQ(calculate_num_threads(2, 2, 1), 2);
EXPECT_EQ(calculate_num_threads(13, 2, 1), 2);
EXPECT_EQ(calculate_num_threads(13, 4, 1), 4);
}
77 changes: 77 additions & 0 deletions fbgemm_gpu/test/tbe/inference/nbit_forward_threading_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os
import subprocess
import tempfile
import unittest
from typing import Optional

import torch

# Path to the worker binary, injected via `$(location ...)` in the BUCK env. This
# test relies on a sibling python_binary located through buck, so it only runs in
# the fbcode build; in the OSS (pytest/CMake) build the env var is absent and the
# test is skipped (use .get(), not [], so import never raises during collection).
_WORKER: Optional[str] = os.environ.get("NBIT_THREADING_WORKER")


def _run(
out_path: str, threads: Optional[int], tables_per_thread: Optional[int]
) -> torch.Tensor:
"""Run the worker in a fresh process with the given threading env and load
its forward output. The thread count is read once (cached) at the first
kernel call, so each setting needs its own process."""
worker = _WORKER
assert worker is not None # guaranteed by the skipUnless on the test class
env = dict(os.environ)
env.pop("FBGEMM_TBE_MAX_NUM_THREADS", None)
env.pop("FBGEMM_TBE_MIN_TABLES_PER_THREAD", None)
if threads is not None:
env["FBGEMM_TBE_MAX_NUM_THREADS"] = str(threads)
if tables_per_thread is not None:
env["FBGEMM_TBE_MIN_TABLES_PER_THREAD"] = str(tables_per_thread)
subprocess.run([worker, out_path], env=env, check=True)
return torch.load(out_path)


@unittest.skipUnless(
_WORKER is not None,
"requires the fbcode worker binary via NBIT_THREADING_WORKER ($(location)); "
"not available in the OSS build",
)
class NBitForwardThreadingTest(unittest.TestCase):
def test_threading_does_not_change_result(self) -> None:
# Each config maps to (FBGEMM_TBE_MAX_NUM_THREADS, FBGEMM_TBE_MIN_TABLES_PER_THREAD).
# Outputs must be BITWISE identical across all of them: table-threading
# partitions independent per-table work into disjoint output slices, with
# no cross-thread reduction, so there is no floating-point reordering.
configs = {
"single_thread": (1, None), # explicit serial
"default_no_env": (None, None), # no env var -> serial path
"2T_guard": (2, None), # 2 threads, default granularity (G=16)
"2T_all": (2, 1), # 2 threads, thread every call
"4T_all": (4, 1), # 4 threads, thread every call
}
with tempfile.TemporaryDirectory() as d:
outputs = {
name: _run(os.path.join(d, f"{name}.pt"), thr, tpt)
for name, (thr, tpt) in configs.items()
}
base = outputs["single_thread"]
self.assertTrue(torch.isfinite(base).all(), "reference output not finite")
for name, out in outputs.items():
self.assertEqual(out.shape, base.shape, f"{name}: shape mismatch")
self.assertTrue(
torch.equal(out, base),
f"{name} output differs from single_thread (threading changed the result)",
)


if __name__ == "__main__":
unittest.main()
54 changes: 54 additions & 0 deletions fbgemm_gpu/test/tbe/inference/nbit_forward_threading_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# Worker for nbit_forward_threading_test: builds a deterministic CPU int-nbit TBE op
# and writes its forward output to the path given as argv[1]. The driver runs this
# under different FBGEMM_TBE_MAX_NUM_THREADS / FBGEMM_TBE_MIN_TABLES_PER_THREAD env values (read once,
# at the first kernel call, hence a separate process per setting) and checks the
# outputs are bitwise identical -- i.e. table-threading does not change the result.
import sys

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
EmbeddingLocation,
PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
IntNBitTableBatchedEmbeddingBagsCodegen,
)


def main() -> None:
out_path = sys.argv[1]
# T=40 > the default threading onset (2*G = 32 at G=16), so even the
# default-granularity arm (FBGEMM_TBE_MAX_NUM_THREADS=2, no FBGEMM_TBE_MIN_TABLES_PER_THREAD)
# genuinely spawns threads rather than falling back to the serial path.
T, E, D, B, L = 40, 1000, 16, 8, 6

# Deterministic weights: same seed + same torch build => identical across the
# worker processes the driver spawns, so the only variable is the thread count.
torch.manual_seed(0)
cc = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[("", E, D, SparseType.INT8, EmbeddingLocation.HOST)] * T,
pooling_mode=PoolingMode.SUM,
device="cpu",
output_dtype=SparseType.FP16,
)
cc.fill_random_weights()

# Deterministic indices/offsets (no RNG): T*B bags, each pooling L indices.
indices = (torch.arange(T * B * L) % E).to(torch.int32)
offsets = (torch.arange(T * B + 1) * L).to(torch.int32)

out = cc(indices, offsets)
torch.save(out.cpu(), out_path)


if __name__ == "__main__":
main()
Loading