Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions infini_train/include/autograd/linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ class Tensor;

namespace infini_train::autograd {

struct LinearGradFlags {
bool input = false;
bool weight = false;
bool bias = false;
};

class Linear : public Function {
public:
static constexpr char kType[] = "LinearFunction";
Expand Down
72 changes: 72 additions & 0 deletions infini_train/include/common/cuda/gemm.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#pragma once

#include <cublas_v2.h>
#include <cuda_runtime_api.h>

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train::kernels::cuda {

/**
* Return the cuBLAS handle associated with the given device.
* Shared by linear.cu, matmul.cu, and any future GEMM-using kernels.
*/
cublasHandle_t GetCublasHandle(const Device &device);

/**
* Return the CUDA stream associated with the given device.
* Shared by kernels that need to launch device-side code directly.
*/
cudaStream_t GetCudaStream(const Device &device);

/**
* Parameter bundle for a single GEMM call:
* C = alpha * op(A) * op(B) + beta * C
*
* batch_count == 1 → non-batched path (cublasGemmEx)
* batch_count > 1 → strided-batched (cublasGemmStridedBatchedEx)
*
* When batch_count == 1, stride_a/b/c are unused and must be left at 0.
*/
struct GemmParams {
    // Operations applied to A and B before the multiply: op(A), op(B).
    cublasOperation_t trans_a = CUBLAS_OP_N;
    cublasOperation_t trans_b = CUBLAS_OP_N;

    int m = 0; // rows of op(A) and C
    int n = 0; // cols of op(B) and C
    int k = 0; // cols of op(A) == rows of op(B)

    // Device pointers and leading dimensions (cuBLAS column-major convention
    // as used by the call sites — TODO confirm against linear.cu/matmul.cu).
    const void *A = nullptr;
    int lda = 0;
    const void *B = nullptr;
    int ldb = 0;
    void *C = nullptr;
    int ldc = 0;

    // Scalars in C = alpha * op(A) * op(B) + beta * C.
    float alpha = 1.0f;
    float beta = 0.0f;

    // batch_count=1: non-batched (Linear path); stride_a/b/c must be 0
    // batch_count>1: strided-batched (Matmul path)
    int batch_count = 1;
    long long stride_a = 0;
    long long stride_b = 0;
    long long stride_c = 0;

    // Defaulted to fp32 so a partially-filled struct is never read
    // uninitialized: these were the only members without in-class
    // initializers, which is UB if a caller forgets to set them.
    DataType input_dtype = DataType::kFLOAT32;  // dtype of A and B
    DataType output_dtype = DataType::kFLOAT32; // dtype of C (may differ, e.g. bf16 in → fp32 out)

    cublasHandle_t blas_handle = nullptr;
};

/**
* Execute the GEMM described by `p` via cuBLAS.
* Dispatches to cublasGemmEx (batch_count==1) or
* cublasGemmStridedBatchedEx (batch_count>1).
* Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision.
* Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)).
*/
void GemmCuda(const GemmParams &p);

} // namespace infini_train::kernels::cuda
30 changes: 21 additions & 9 deletions infini_train/src/autograd/linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,29 @@ std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::sha
const auto &grad_output = grad_outputs[0];

CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward";
LinearGradFlags grad_flags = {.input = needs_input_grad_[0],
.weight = needs_input_grad_.size() > 1 && needs_input_grad_[1],
.bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]};
bool need_grad_input = needs_input_grad_[0];
bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2];

auto device = grad_output->GetDevice().type();
// TODO: skip autograd graph construction entirely when no input requires grad
auto [grad_input, grad_weight, grad_bias]
= Dispatcher::Instance()
.Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_,
grad_output, bias_, grad_flags);

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_weight = nullptr;
std::shared_ptr<Tensor> grad_bias = nullptr;

if (need_grad_input) {
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardInput"}, weight, grad_output, transpose_, in_features_, out_features_, input_dims_);
}
if (need_grad_weight) {
grad_weight = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardWeight"}, input, grad_output, transpose_, in_features_, out_features_);
}
if (need_grad_bias) {
grad_bias = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "LinearBackwardBias"}, grad_output,
out_features_);
}

if (bias_) {
return {grad_input, grad_weight, grad_bias};
} else {
Expand Down
35 changes: 28 additions & 7 deletions infini_train/src/autograd/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
// determined by autocast, not derived from output->Dtype().
auto compute_dtype = output->Dtype();
saved_tensors_ = {
input1->Dtype() == compute_dtype ? input1 : std::make_shared<Tensor>(input1->To(compute_dtype)),
input2->Dtype() == compute_dtype ? input2 : std::make_shared<Tensor>(input2->To(compute_dtype)),

// grad_input1 = grad_output @ input2^T, so input2 is needed
// grad_input2 = grad_output^T @ input1, so input1 is needed
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto cast = [&](const std::shared_ptr<Tensor> &t) {
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
};

saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
out_features_ = output->Dims()[0];
}

Expand All @@ -45,10 +52,24 @@ std::vector<std::shared_ptr<Tensor>> Matmul::Backward(const std::vector<std::sha
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Matmul::Backward";
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto device = input1->GetDevice().type();
auto [grad_input1, grad_input2]
= Dispatcher::Instance().Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "MatmulBackward"}, input1, input2, grad_output);
return {grad_input1, grad_input2};

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_other = nullptr;

if (need_grad_input1) {
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput"}, input2,
grad_output, input1->Dims());
}
if (need_grad_input2) {
grad_other = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardOther"}, input1,
grad_output, input2->Dims());
}

return {grad_input, grad_other};
}
} // namespace infini_train::autograd
179 changes: 40 additions & 139 deletions infini_train/src/kernels/cpu/linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,103 +5,10 @@

#include "glog/logging.h"

#include "infini_train/include/autograd/linear.h"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

namespace infini_train::kernels::cpu {
std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other) {
    /*
    Naive batched matrix multiply on CPU (fp32 only):
        output[*, m, n] = input[*, m, k] * other[*, k, n]
    All leading (batch) dims of `input` and `other` must match exactly.
    */
    // TODO(dcj): support broadcast later
    const auto &input_dims = input->Dims();
    const auto &other_dims = other->Dims();

    CHECK_GE(input_dims.size(), 2);
    CHECK_GE(other_dims.size(), 2);
    CHECK_EQ(input_dims.size(), other_dims.size());

    const int64_t m = input_dims[input_dims.size() - 2];
    const int64_t k = input_dims[input_dims.size() - 1];
    CHECK_EQ(k, other_dims[other_dims.size() - 2]);
    const int64_t n = other_dims[other_dims.size() - 1];

    // Product of all batch dims; 1 for plain 2-D matrices.
    const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
    // Cast avoids the signed/unsigned comparison of int64_t against size_t.
    for (int64_t i = 0; i < static_cast<int64_t>(input_dims.size()) - 2; ++i) {
        CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
    }

    std::vector<int64_t> output_dims = input_dims;
    output_dims[output_dims.size() - 1] = n;
    auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);

    // Hoist the loop-invariant DataPtr() casts out of the hot loop; the
    // original re-cast them on every innermost iteration.
    const float *input_ptr = static_cast<const float *>(input->DataPtr());
    const float *other_ptr = static_cast<const float *>(other->DataPtr());
    float *output_ptr = static_cast<float *>(output->DataPtr());

    for (int64_t b = 0; b < bs; ++b) {
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (int64_t p = 0; p < k; ++p) {
                    acc += input_ptr[b * m * k + i * k + p] * other_ptr[b * k * n + p * n + j];
                }
                output_ptr[b * m * n + i * n + j] = acc;
            }
        }
    }
    return output;
}

std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
               const std::shared_ptr<Tensor> &grad_output) {
    /*
    Gradients of the naive batched matmul (fp32 only):
        grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T
        grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
    */
    const auto &input_dims = input->Dims();
    const auto &other_dims = other->Dims();
    const auto &grad_output_dims = grad_output->Dims();

    CHECK_GE(input_dims.size(), 2);
    CHECK_EQ(input_dims.size(), other_dims.size());
    CHECK_EQ(input_dims.size(), grad_output_dims.size());

    const int64_t m = input_dims[input_dims.size() - 2];
    const int64_t k = input_dims[input_dims.size() - 1];
    CHECK_EQ(k, other_dims[other_dims.size() - 2]);
    const int64_t n = other_dims[other_dims.size() - 1];

    CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
    CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]);

    // Product of all batch dims; 1 for plain 2-D matrices.
    const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
    // Cast avoids the signed/unsigned comparison of int64_t against size_t.
    for (int64_t i = 0; i < static_cast<int64_t>(input_dims.size()) - 2; ++i) {
        CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
        CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match";
    }

    auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
    auto grad_other = std::make_shared<Tensor>(other_dims, DataType::kFLOAT32);
    grad_input->Fill<float>(0.0f);
    grad_other->Fill<float>(0.0f);

    // Hoist the loop-invariant DataPtr() casts out of the hot loop; the
    // original re-cast them on every innermost iteration.
    const float *input_ptr = static_cast<const float *>(input->DataPtr());
    const float *other_ptr = static_cast<const float *>(other->DataPtr());
    const float *grad_output_ptr = static_cast<const float *>(grad_output->DataPtr());
    float *grad_input_ptr = static_cast<float *>(grad_input->DataPtr());
    float *grad_other_ptr = static_cast<float *>(grad_other->DataPtr());

    for (int64_t b = 0; b < bs; ++b) {
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < n; ++j) {
                const float grad = grad_output_ptr[b * m * n + i * n + j];
                for (int64_t p = 0; p < k; ++p) {
                    const auto input_idx = b * m * k + i * k + p;
                    const auto other_idx = b * k * n + p * n + j;
                    grad_input_ptr[input_idx] += grad * other_ptr[other_idx];
                    grad_other_ptr[other_idx] += grad * input_ptr[input_idx];
                }
            }
        }
    }
    return {grad_input, grad_other};
}

std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
bool transpose, const std::shared_ptr<Tensor> &bias) {
Expand Down Expand Up @@ -146,71 +53,65 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
return output;
}

// TODO(dcj): support linear without bias later
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
int64_t in_features, int64_t out_features, const std::vector<int64_t> &input_dims,
const std::shared_ptr<Tensor> &grad_output, bool bias,
infini_train::autograd::LinearGradFlags grad_flags) {
std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weight,
                                            const std::shared_ptr<Tensor> &grad_output, bool transpose,
                                            int64_t in_features, int64_t out_features,
                                            const std::vector<int64_t> &input_dims) {
    /*
    Gradient of Linear w.r.t. the input only.

    transpose: grad_input = grad_output * weight
        grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]

    !transpose: grad_input = grad_output * weight^T
        grad_input[*, in_features] = grad_output[*, out_features] * weight[in_features, out_features]^T
    */
    CHECK_GE(input_dims.size(), 2);

    auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
    if (transpose) {
        grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
    } else {
        grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
    }
    return grad_input;
}

if (compute_grad_weight) {
CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)";
grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
if (transpose) {
grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
} else {
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
}
}
std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &input,
                                             const std::shared_ptr<Tensor> &grad_output, bool transpose,
                                             int64_t in_features, int64_t out_features) {
    /*
    Gradient of Linear w.r.t. the weight only.

    transpose:
        grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]

    !transpose:
        grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
    */
    std::vector<int64_t> weight_dims
        = transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
    auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
    if (transpose) {
        grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
    } else {
        grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
    }
    return grad_weight;
}

return {grad_input, grad_weight, grad_bias};
std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_output, int64_t out_features) {
    // Gradient of Linear w.r.t. the bias:
    //   grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
    const std::vector<int64_t> bias_dims{out_features};
    auto grad_bias = std::make_shared<Tensor>(bias_dims, DataType::kFLOAT32);
    // Column-wise sum collapses every leading (batch) row into one vector.
    grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
    return grad_bias;
}

} // namespace infini_train::kernels::cpu

#define REGISTER_CPU_LINEAR_KERNEL(kernel_name) \
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)

REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
REGISTER_CPU_LINEAR_KERNEL(MatmulBackward)
REGISTER_CPU_LINEAR_KERNEL(LinearForward)
REGISTER_CPU_LINEAR_KERNEL(LinearBackward)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardBias)

#undef REGISTER_CPU_LINEAR_KERNEL
Loading
Loading