Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ set(APHRODITE_EXT_SRC
"kernels/quantization/activation_kernels.cu"
"kernels/cuda_utils_kernels.cu"
"kernels/all_reduce/custom_all_reduce.cu"
"kernels/all_gather/custom_all_gather.cu"
"kernels/torch_bindings.cpp")

if(APHRODITE_GPU_LANG STREQUAL "CUDA")
Expand Down
205 changes: 205 additions & 0 deletions kernels/all_gather/custom_all_gather.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#include "custom_all_gather.cuh"

#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <numeric>

#if ENABLE_MULTI_DEVICE
namespace aphrodite {
namespace runtime {
namespace TorchUtils {
// Element types understood by the all-gather helpers. The underlying values
// are consecutive starting at zero; kCOUNT is the number of real entries.
enum class DataType : int32_t {
  kFP32 = 0,
  kFP16 = 1,
  kBF16 = 2,
  kINT32 = 3,
  kINT8 = 4,
  kUINT8 = 5,
  kINT64 = 6,
  kBOOL = 7,
  kCOUNT = 8,
};

// Translates a torch scalar type into the corresponding DataType enumerator.
// A scalar type without a mapping fails TORCH_CHECK with
// "Unsupported data type".
inline DataType dataType(torch::ScalarType type) {
  using Scalar = torch::ScalarType;

  if (type == Scalar::Float) {
    return DataType::kFP32;
  }
  if (type == Scalar::Half) {
    return DataType::kFP16;
  }
  if (type == Scalar::BFloat16) {
    return DataType::kBF16;
  }
  if (type == Scalar::Int) {
    return DataType::kINT32;
  }
  if (type == Scalar::Char) {
    return DataType::kINT8;
  }
  if (type == Scalar::Byte) {
    return DataType::kUINT8;
  }
  if (type == Scalar::Long) {
    return DataType::kINT64;
  }
  if (type == Scalar::Bool) {
    return DataType::kBOOL;
  }
  TORCH_CHECK(false, "Unsupported data type");
}
} // namespace TorchUtils
} // namespace runtime
} // namespace aphrodite

// Fallback for the trace-logging macro: when no TLLM_LOG_TRACE is already
// defined, calls expand to an empty do/while(0) statement, so the arguments
// are discarded and never evaluated.
#ifndef TLLM_LOG_TRACE
#define TLLM_LOG_TRACE(...) \
do { \
} while (0)
#endif

#ifndef COMM_SESSION
// Placeholder session object used by the trace calls in this file; it always
// reports rank 0.
// NOTE(review): the guard only detects a preprocessor macro named
// COMM_SESSION, not an object of that name declared elsewhere.
struct DummyCommSession {
  int getRank() const { return 0; }
};

DummyCommSession COMM_SESSION;
#endif

namespace aphrodite {

namespace {

const std::unordered_map<aphrodite::runtime::TorchUtils::DataType, ncclDataType_t>* getDtypeMap() {
static const std::unordered_map<aphrodite::runtime::TorchUtils::DataType, ncclDataType_t> dtypeMap = {
{aphrodite::runtime::TorchUtils::DataType::kFP32, ncclFloat},
{aphrodite::runtime::TorchUtils::DataType::kFP16, ncclHalf},
{aphrodite::runtime::TorchUtils::DataType::kBF16, ncclBfloat16},
{aphrodite::runtime::TorchUtils::DataType::kINT32, ncclInt32},
{aphrodite::runtime::TorchUtils::DataType::kINT8, ncclInt8},
{aphrodite::runtime::TorchUtils::DataType::kUINT8, ncclUint8},
{aphrodite::runtime::TorchUtils::DataType::kINT64, ncclInt64},
{aphrodite::runtime::TorchUtils::DataType::kBOOL, ncclUint8}, // NCCL doesn't have bool, use uint8
};
return &dtypeMap;
}

class AllgatherOp {
public:
AllgatherOp(std::set<int> group, ncclComm_t comm)
: mGroup(std::move(group)), mNcclComm(comm)
{
}

~AllgatherOp() = default;

void initialize() {
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
// NCCL communicator is now passed in the constructor
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
}

std::vector<torch::Tensor> run_list(torch::TensorList input_list, const std::optional<std::vector<int64_t>>& sizes) {
TORCH_CHECK(mNcclComm != nullptr, "NCCL communicator not initialized.");

bool use_nccl_allgather = !sizes.has_value() ||
std::all_of(sizes.value().begin(), sizes.value().end(),
[&sizes](int64_t size) { return size == sizes.value()[0]; });

int64_t sum_sizes = sizes.has_value() ? std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{}) : 0;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The initial value for std::accumulate is 0, which is an int. Since the sizes vector contains int64_t values, the sum could overflow an int if it exceeds INT_MAX. The accumulator's type is determined by the type of this initial value. To prevent overflow, please use an int64_t initial value.

        int64_t sum_sizes = sizes.has_value() ? std::accumulate(sizes.value().begin(), sizes.value().end(), int64_t{0}, std::plus<>{}) : 0;


std::vector<torch::Tensor> output_list;
output_list.reserve(input_list.size());

// NCCLCHECK_THROW(ncclGroupStart()); // Group operations might be managed by Aphrodite's distributed backend
for (auto const& input : input_list) {
TORCH_CHECK(input.is_cuda(), "Input tensor must be on CUDA device.");
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
auto type = aphrodite::runtime::TorchUtils::dataType(input.scalar_type());

std::vector<int64_t> outputShape = input.sizes().vec();
if (sizes.has_value()) {
outputShape[0] = sum_sizes;
} else {
outputShape[0] *= mGroup.size();
}
auto output = torch::empty(outputShape, input.options());

if (use_nccl_allgather) {
AT_CUDA_CHECK(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(), (*getDtypeMap())[type],
mNcclComm, stream));
} else {
size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The initial value for std::accumulate is 1, which is an int. The product of tensor dimensions can easily overflow an int. The accumulator's type is determined by this initial value. Please use a size_t initial value to prevent potential overflow, as the result is stored in a size_t.

                size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), size_t{1}, std::multiplies<>{});

int64_t split_offset = 0;

for (int root_idx = 0; root_idx < static_cast<int>(mGroup.size()); ++root_idx) {
auto it = mGroup.begin();
std::advance(it, root_idx);
int root_rank = *it;

auto split_size = sizes.value()[root_idx];
AT_CUDA_CHECK(ncclBroadcast(input.data_ptr(),
output.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).mutable_data_ptr(),
numel_base * split_size, (*getDtypeMap())[type], root_rank, mNcclComm, stream));
split_offset += split_size;
}
}
output_list.push_back(output);
}
// NCCLCHECK_THROW(ncclGroupEnd()); // Group operations might be managed by Aphrodite's distributed backend
return output_list;
}

torch::Tensor run(torch::Tensor input, const std::optional<std::vector<int64_t>>& sizes) {
return run_list({input}, sizes)[0];
}

private:
std::set<int> mGroup;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Using std::set for mGroup introduces a critical bug and is inefficient. std::set sorts its elements, which will break the correspondence between the ranks and the sizes vector if the input group_ranks is not sorted. This leads to incorrect data gathering when sizes are variable. Additionally, iterating over the set using std::advance inside run_list is inefficient.

Please change mGroup to be a std::vector<int> to preserve the order of ranks and allow for efficient indexing. This requires changes in multiple places:

  1. Change mGroup's type to std::vector<int> here.
  2. Update AllgatherOp constructor (lines 72-74) to accept std::vector<int>.
  3. Update init_custom_ag (lines 148-154) to construct a std::vector<int> from group_ranks and pass it to AllgatherOp.
  4. Update the loop in run_list (lines 118-121) to use mGroup[root_idx] instead of std::advance.
    std::vector<int> mGroup;

ncclComm_t mNcclComm; // Stored NCCL communicator
};

} // namespace

// C++ functions to be exposed to Python
// Builds an AllgatherOp for `group_ranks` around the NCCL communicator whose
// address is passed as the integer `nccl_comm_ptr`, and returns the object's
// address as an opaque integer handle. The handle must later be released with
// dispose_custom_ag().
fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks, int64_t nccl_comm_ptr) {
  // Ranks arrive as 64-bit integers and are narrowed to int for the op.
  std::set<int> ranks;
  for (const int64_t rank : group_ranks) {
    ranks.insert(static_cast<int>(rank));
  }

  auto holder = std::make_unique<AllgatherOp>(ranks, reinterpret_cast<ncclComm_t>(nccl_comm_ptr));
  holder->initialize();
  // Ownership passes to the caller through the integer handle.
  return reinterpret_cast<fptr_t>(holder.release());
}

// Runs an all-gather of `input` on the op behind the handle `_ag_op` and
// copies the gathered tensor into the caller-provided `output`.
void all_gather(fptr_t _ag_op, torch::Tensor& input, torch::Tensor& output,
                const std::optional<std::vector<int64_t>>& sizes) {
  auto* const gather_op = reinterpret_cast<AllgatherOp*>(_ag_op);
  // run() allocates its own result; copy it into the provided output tensor.
  output.copy_(gather_op->run(input, sizes));
}

// Runs an all-gather for every tensor in `input_list` on the op behind the
// handle `_ag_op` and returns the gathered tensors.
std::vector<torch::Tensor> all_gather_list(
    fptr_t _ag_op, const std::vector<torch::Tensor>& input_list,
    const std::optional<std::vector<int64_t>>& sizes) {
  auto* const gather_op = reinterpret_cast<AllgatherOp*>(_ag_op);
  std::vector<torch::Tensor> gathered = gather_op->run_list(input_list, sizes);
  return gathered;
}

// Destroys the AllgatherOp behind the handle `_ag_op`. The handle must not be
// used afterwards.
void dispose_custom_ag(fptr_t _ag_op) {
  delete reinterpret_cast<AllgatherOp*>(_ag_op);
}

} // namespace aphrodite

#else // ENABLE_MULTI_DEVICE

// Dummy implementations for when multi-device is not enabled
namespace aphrodite {

namespace {
// Error text shared by every stub below.
constexpr const char* kMultiDeviceDisabled = "Multi-device support not enabled.";
}  // namespace

// Stub: always fails, because the build has no multi-device support.
fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks, int64_t nccl_comm_ptr) {
  TORCH_CHECK(false, kMultiDeviceDisabled);
  return 0;
}

// Stub: always fails, because the build has no multi-device support.
void all_gather(fptr_t _ag_op, torch::Tensor& input, torch::Tensor& output,
                const std::optional<std::vector<int64_t>>& sizes) {
  TORCH_CHECK(false, kMultiDeviceDisabled);
}

// Stub: always fails, because the build has no multi-device support.
std::vector<torch::Tensor> all_gather_list(
    fptr_t _ag_op, const std::vector<torch::Tensor>& input_list,
    const std::optional<std::vector<int64_t>>& sizes) {
  TORCH_CHECK(false, kMultiDeviceDisabled);
  return {};
}

// Stub: always fails, because the build has no multi-device support.
void dispose_custom_ag(fptr_t _ag_op) {
  TORCH_CHECK(false, kMultiDeviceDisabled);
}

}  // namespace aphrodite

#endif // ENABLE_MULTI_DEVICE
30 changes: 30 additions & 0 deletions kernels/all_gather/custom_all_gather.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <cuda_runtime.h>
#include <torch/all.h>
#include <set>
#include <vector>
#include <memory>

#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif

namespace aphrodite {

// No forward declaration of AllgatherOp here: the class is defined in an
// unnamed namespace inside custom_all_gather.cu, and a second declaration of
// the same name directly in `aphrodite` makes unqualified uses of
// `AllgatherOp` in that file ambiguous. The public API below only traffics in
// integer handles, so the type does not need to be visible.

// Integer type used to pass the address of a C++ object to Python and back.
using fptr_t = int64_t;

// C++ functions to be exposed to Python. When ENABLE_MULTI_DEVICE is off,
// every one of them fails with "Multi-device support not enabled.".

// Creates an all-gather op for `group_ranks` around the NCCL communicator
// whose address is `nccl_comm_ptr`; returns an opaque handle to it.
fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks, int64_t nccl_comm_ptr);

// Gathers `input` on the op behind `_ag_op` and copies the result into
// `output`. `sizes` optionally gives the per-rank dim-0 extents.
void all_gather(fptr_t _ag_op, torch::Tensor& input, torch::Tensor& output,
                const std::optional<std::vector<int64_t>>& sizes);

// Gathers every tensor in `input_list` on the op behind `_ag_op` and returns
// the gathered tensors.
std::vector<torch::Tensor> all_gather_list(
    fptr_t _ag_op, const std::vector<torch::Tensor>& input_list,
    const std::optional<std::vector<int64_t>>& sizes);

// Destroys the op behind `_ag_op`.
void dispose_custom_ag(fptr_t _ag_op);

}  // namespace aphrodite
8 changes: 8 additions & 0 deletions kernels/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);

// Creates a custom all-gather op. The second parameter carries the address of
// the NCCL communicator as an integer; it is required so this declaration
// matches the definition in kernels/all_gather/custom_all_gather.cu.
fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks, int64_t nccl_comm_ptr);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The signature of init_custom_ag is inconsistent with its implementation in custom_all_gather.cu and its declaration in custom_all_gather.cuh. The implementation expects a second argument nccl_comm_ptr of type int64_t, which is missing here. This will cause a linker error.

Suggested change
fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks);
fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks, int64_t nccl_comm_ptr);

// Custom all-gather entry points; the definitions live in
// kernels/all_gather/custom_all_gather.cu.
// all_gather: gathers `input` and copies the result into `output`.
void all_gather(fptr_t _ag_op, torch::Tensor& input, torch::Tensor& output,
const std::optional<std::vector<int64_t>>& sizes);
// all_gather_list: gathers every tensor in `input_list` and returns the results.
std::vector<torch::Tensor> all_gather_list(
fptr_t _ag_op, const std::vector<torch::Tensor>& input_list,
const std::optional<std::vector<int64_t>>& sizes);
// dispose_custom_ag: destroys the op behind the handle `_ag_op`.
void dispose_custom_ag(fptr_t _ag_op);

void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& C,
Expand Down
17 changes: 17 additions & 0 deletions kernels/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -927,4 +927,21 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
#endif
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ag), custom_ag) {
  // Custom all-gather kernels

  // The schema lists both arguments of the C++ function: the rank list and
  // the NCCL communicator address passed as an integer.
  custom_ag.def(
      "init_custom_ag(int[] group_ranks, int nccl_comm_ptr) -> int");
  // This op takes no Tensor argument, so there is no tensor to select a
  // backend from; the kernel is registered without a dispatch key.
  custom_ag.impl("init_custom_ag", &init_custom_ag);

  custom_ag.def(
      "all_gather(int ag_op, Tensor input, Tensor! output, int[]? sizes) -> ()");
  custom_ag.impl("all_gather", torch::kCUDA, &all_gather);

  custom_ag.def(
      "all_gather_list(int ag_op, Tensor[] input_list, int[]? sizes) -> Tensor[]");
  custom_ag.impl("all_gather_list", torch::kCUDA, &all_gather_list);

  // Schema is inferred from the C++ signature.
  custom_ag.def("dispose_custom_ag", &dispose_custom_ag);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
Loading