From 9200e94d6ddd51d96f17386718c5b04fb8af536e Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 19 Sep 2025 10:22:08 +0430 Subject: [PATCH] [Kernel][Comms] feat: add custom all-gather kernels --- CMakeLists.txt | 1 + kernels/all_gather/custom_all_gather.cu | 205 +++++++++++++++++++++++ kernels/all_gather/custom_all_gather.cuh | 30 ++++ kernels/ops.h | 8 + kernels/torch_bindings.cpp | 17 ++ 5 files changed, 261 insertions(+) create mode 100644 kernels/all_gather/custom_all_gather.cu create mode 100644 kernels/all_gather/custom_all_gather.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index c555a5ed44..b70430ded2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -257,6 +257,7 @@ set(APHRODITE_EXT_SRC "kernels/quantization/activation_kernels.cu" "kernels/cuda_utils_kernels.cu" "kernels/all_reduce/custom_all_reduce.cu" + "kernels/all_gather/custom_all_gather.cu" "kernels/torch_bindings.cpp") if(APHRODITE_GPU_LANG STREQUAL "CUDA") diff --git a/kernels/all_gather/custom_all_gather.cu b/kernels/all_gather/custom_all_gather.cu new file mode 100644 index 0000000000..49346d38a2 --- /dev/null +++ b/kernels/all_gather/custom_all_gather.cu @@ -0,0 +1,205 @@ +#include "custom_all_gather.cuh" + +#include +#include +#include + +#if ENABLE_MULTI_DEVICE +namespace aphrodite { +namespace runtime { +namespace TorchUtils { +enum class DataType : int32_t { + kFP32, + kFP16, + kBF16, + kINT32, + kINT8, + kUINT8, + kINT64, + kBOOL, + kCOUNT, +}; + +inline DataType dataType(torch::ScalarType type) { + switch (type) { + case torch::ScalarType::Float: return DataType::kFP32; + case torch::ScalarType::Half: return DataType::kFP16; + case torch::ScalarType::BFloat16: return DataType::kBF16; + case torch::ScalarType::Int: return DataType::kINT32; + case torch::ScalarType::Char: return DataType::kINT8; + case torch::ScalarType::Byte: return DataType::kUINT8; + case torch::ScalarType::Long: return DataType::kINT64; + case torch::ScalarType::Bool: return DataType::kBOOL; + default: TORCH_CHECK(false, "Unsupported data type"); + } +} +} // namespace TorchUtils +} // namespace runtime +} // namespace aphrodite + +#ifndef TLLM_LOG_TRACE +#define TLLM_LOG_TRACE(...) \ + do { \ + } while (0) +#endif + +#ifndef COMM_SESSION +struct DummyCommSession { + int getRank() const { return 0; } +} COMM_SESSION; +#endif + +namespace aphrodite { + +namespace { + +const std::unordered_map* getDtypeMap() { + static const std::unordered_map dtypeMap = { + {aphrodite::runtime::TorchUtils::DataType::kFP32, ncclFloat}, + {aphrodite::runtime::TorchUtils::DataType::kFP16, ncclHalf}, + {aphrodite::runtime::TorchUtils::DataType::kBF16, ncclBfloat16}, + {aphrodite::runtime::TorchUtils::DataType::kINT32, ncclInt32}, + {aphrodite::runtime::TorchUtils::DataType::kINT8, ncclInt8}, + {aphrodite::runtime::TorchUtils::DataType::kUINT8, ncclUint8}, + {aphrodite::runtime::TorchUtils::DataType::kINT64, ncclInt64}, + {aphrodite::runtime::TorchUtils::DataType::kBOOL, ncclUint8}, // NCCL doesn't have bool, use uint8 + }; + return &dtypeMap; +} + +class AllgatherOp { +public: + AllgatherOp(std::set group, ncclComm_t comm) + : mGroup(std::move(group)), mNcclComm(comm) + { + } + + ~AllgatherOp() = default; + + void initialize() { + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank()); + // NCCL communicator is now passed in the constructor + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank()); + } + + std::vector run_list(torch::TensorList input_list, const std::optional>& sizes) { + TORCH_CHECK(mNcclComm != nullptr, "NCCL communicator not initialized."); + + bool use_nccl_allgather = !sizes.has_value() || + std::all_of(sizes.value().begin(), sizes.value().end(), + [&sizes](int64_t size) { return size == sizes.value()[0]; }); + + int64_t sum_sizes = sizes.has_value() ? std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{}) : 0; + + std::vector output_list; + output_list.reserve(input_list.size()); + + // NCCLCHECK_THROW(ncclGroupStart()); // Group operations might be managed by Aphrodite's distributed backend + for (auto const& input : input_list) { + TORCH_CHECK(input.is_cuda(), "Input tensor must be on CUDA device."); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + auto type = aphrodite::runtime::TorchUtils::dataType(input.scalar_type()); + + std::vector outputShape = input.sizes().vec(); + if (sizes.has_value()) { + outputShape[0] = sum_sizes; + } else { + outputShape[0] *= mGroup.size(); + } + auto output = torch::empty(outputShape, input.options()); + + if (use_nccl_allgather) { + AT_CUDA_CHECK(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(), (*getDtypeMap())[type], + mNcclComm, stream)); + } else { + size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{}); + int64_t split_offset = 0; + + for (int root_idx = 0; root_idx < static_cast(mGroup.size()); ++root_idx) { + auto it = mGroup.begin(); + std::advance(it, root_idx); + int root_rank = *it; + + auto split_size = sizes.value()[root_idx]; + AT_CUDA_CHECK(ncclBroadcast(input.data_ptr(), + output.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).mutable_data_ptr(), + numel_base * split_size, (*getDtypeMap())[type], root_rank, mNcclComm, stream)); + split_offset += split_size; + } + } + output_list.push_back(output); + } + // NCCLCHECK_THROW(ncclGroupEnd()); // Group operations might be managed by Aphrodite's distributed backend + return output_list; + } + + torch::Tensor run(torch::Tensor input, const std::optional>& sizes) { + return run_list({input}, sizes)[0]; + } + +private: + std::set mGroup; + ncclComm_t mNcclComm; // Stored NCCL communicator +}; + +} // namespace + +// C++ functions to be exposed to Python +fptr_t init_custom_ag(const std::vector& group_ranks, int64_t nccl_comm_ptr) { + std::set group_set(group_ranks.begin(), group_ranks.end()); + ncclComm_t comm = reinterpret_cast(nccl_comm_ptr); + AllgatherOp* op = new AllgatherOp(group_set, comm); + op->initialize(); + return reinterpret_cast(op); +} + +void all_gather(fptr_t _ag_op, torch::Tensor& input, torch::Tensor& output, + const std::optional>& sizes) { + AllgatherOp* op = reinterpret_cast(_ag_op); + torch::Tensor result = op->run(input, sizes); + output.copy_(result); // Copy result to the provided output tensor +} + +std::vector all_gather_list( + fptr_t _ag_op, const std::vector& input_list, + const std::optional>& sizes) { + AllgatherOp* op = reinterpret_cast(_ag_op); + return op->run_list(input_list, sizes); +} + +void dispose_custom_ag(fptr_t _ag_op) { + AllgatherOp* op = reinterpret_cast(_ag_op); + delete op; +} + +} // namespace aphrodite + +#else // ENABLE_MULTI_DEVICE + +// Dummy implementations for when multi-device is not enabled +namespace aphrodite { + +fptr_t init_custom_ag(const std::vector& group_ranks, int64_t nccl_comm_ptr) { + TORCH_CHECK(false, "Multi-device support not enabled."); + return 0; +} + +void all_gather(fptr_t _ag_op, torch::Tensor& input, torch::Tensor& output, + const std::optional>& sizes) { + TORCH_CHECK(false, "Multi-device support not enabled."); +} + +std::vector all_gather_list( + fptr_t _ag_op, const std::vector& input_list, + const std::optional>& sizes) { + TORCH_CHECK(false, "Multi-device support not enabled."); + return {}; +} + +void dispose_custom_ag(fptr_t _ag_op) { + TORCH_CHECK(false, "Multi-device support not enabled."); +} + +} // namespace aphrodite + +#endif // ENABLE_MULTI_DEVICE \ No newline at end of file diff --git a/kernels/all_gather/custom_all_gather.cuh b/kernels/all_gather/custom_all_gather.cuh new file mode 100644 index 0000000000..9cdabdcd9d --- /dev/null +++ b/kernels/all_gather/custom_all_gather.cuh @@ -0,0 +1,30 @@ +#pragma once + +#include +#include +#include +#include +#include + +#if ENABLE_MULTI_DEVICE +#include +#endif + +namespace aphrodite { + +// Forward declaration of the AllgatherOp class +class AllgatherOp; + +// Define fptr_t for passing C++ object pointers to Python +using fptr_t = int64_t; + +// C++ functions to be exposed to Python +fptr_t init_custom_ag(const std::vector& group_ranks, int64_t nccl_comm_ptr); +void all_gather(fptr_t _ag_op, torch::Tensor& input, torch::Tensor& output, + const std::optional>& sizes); +std::vector all_gather_list( + fptr_t _ag_op, const std::vector& input_list, + const std::optional>& sizes); +void dispose_custom_ag(fptr_t _ag_op); + +} // namespace aphrodite diff --git a/kernels/ops.h b/kernels/ops.h index 4cb90f5594..75a6cf8b88 100644 --- a/kernels/ops.h +++ b/kernels/ops.h @@ -185,6 +185,14 @@ std::tuple allocate_shared_buffer_and_handle( int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); +fptr_t init_custom_ag(const std::vector& group_ranks); +void all_gather(fptr_t _ag_op, torch::Tensor& input, torch::Tensor& output, + const std::optional>& sizes); +std::vector all_gather_list( + fptr_t _ag_op, const std::vector& input_list, + const std::optional>& sizes); +void dispose_custom_ag(fptr_t _ag_op); + void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, diff --git a/kernels/torch_bindings.cpp b/kernels/torch_bindings.cpp index 2536008820..9c07634f9e 100644 --- a/kernels/torch_bindings.cpp +++ b/kernels/torch_bindings.cpp @@ -927,4 +927,21 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { #endif } +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ag), custom_ag) { + // Custom all-gather kernels + custom_ag.def( + "init_custom_ag(int[] group_ranks) -> int"); + custom_ag.impl("init_custom_ag", torch::kCUDA, &init_custom_ag); + + custom_ag.def( + "all_gather(int ag_op, Tensor input, Tensor! output, int[]? sizes) -> ()"); + custom_ag.impl("all_gather", torch::kCUDA, &all_gather); + + custom_ag.def( + "all_gather_list(int ag_op, Tensor[] input_list, int[]? sizes) -> Tensor[]"); + custom_ag.impl("all_gather_list", torch::kCUDA, &all_gather_list); + + custom_ag.def("dispose_custom_ag", &dispose_custom_ag); +} + REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file