diff --git a/.gitmodules b/.gitmodules index 470cf466..578e24f9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,12 @@ +[submodule "third_party/googletest"] + path = third_party/googletest + url = https://github.com/google/googletest.git [submodule "third_party/glog"] path = third_party/glog - url = git@github.com:google/glog.git + url = https://github.com/google/glog.git [submodule "third_party/gflags"] path = third_party/gflags - url = git@github.com:gflags/gflags.git + url = https://github.com/gflags/gflags.git [submodule "third_party/eigen"] path = third_party/eigen - url = git@github.com:InfiniTensor/eigen-mirror.git + url = https://github.com/eigenteam/eigen-git-mirror.git diff --git a/CMakeLists.txt b/CMakeLists.txt index df636b27..22dcf791 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,7 @@ option(USE_CUDA "Support NVIDIA CUDA" OFF) option(PROFILE_MODE "ENABLE PROFILE MODE" OFF) option(USE_OMP "Use OpenMP as backend for Eigen" ON) option(USE_NCCL "Build project for distributed running" ON) +option(BUILD_TEST "Build InfiniTrain tests" OFF) project(infini_train VERSION 0.5.0 LANGUAGES CXX) @@ -14,6 +15,19 @@ set(CMAKE_CXX_EXTENSIONS OFF) # Generate compile_commands.json set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +# ------------------------------------------------------------------------------ +# GoogleTest (submodule) +# ------------------------------------------------------------------------------ +if(BUILD_TEST) + if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/third_party/googletest/CMakeLists.txt) + message(FATAL_ERROR "googletest submodule not found at third_party/googletest. 
" + "Run: git submodule update --init third_party/googletest") + endif() + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + add_subdirectory(third_party/googletest) + enable_testing() +endif() + # ------------------------------------------------------------------------------ # Third-party deps # ------------------------------------------------------------------------------ @@ -26,7 +40,9 @@ include_directories(${gflags_SOURCE_DIR}/include) set(WITH_GFLAGS OFF CACHE BOOL "Disable glog finding system gflags" FORCE) set(WITH_GTEST OFF CACHE BOOL "Disable glog finding system gtest" FORCE) add_subdirectory(third_party/glog) +# add_compile_definitions(GLOG_USE_GLOG_EXPORT=1) include_directories(${glog_SOURCE_DIR}/src) +# include_directories(${glog_BINARY_DIR}/glog) # eigen if(USE_OMP) @@ -48,6 +64,10 @@ endif() # Framework core sources (*.cc), excluding cpu kernels (they are built separately) file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc) list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*") +if(NOT USE_CUDA) + list(FILTER SRC EXCLUDE REGEX ".*runtime/cuda/.*") + list(FILTER SRC EXCLUDE REGEX ".*ccl/cuda/.*") +endif() if(NOT USE_NCCL) list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*") endif() @@ -190,17 +210,8 @@ add_executable(llama3 ) link_infini_train_exe(llama3) -# Tools -add_subdirectory(tools/infini_run) -set_target_properties(infini_run PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) # Tests -add_executable(test_hook test/hook/test_hook.cc) -link_infini_train_exe(test_hook) - -add_executable(test_precision_check test/hook/test_precision_check.cc) -link_infini_train_exe(test_precision_check) - -add_executable(test_lora test/lora/test_lora.cc) -link_infini_train_exe(test_lora) - +if(BUILD_TEST) + add_subdirectory(tests) +endif() diff --git a/infini_train/include/autograd/function.h b/infini_train/include/autograd/function.h index 651b6bf3..c7221d71 100644 --- a/infini_train/include/autograd/function.h +++ 
b/infini_train/include/autograd/function.h @@ -47,6 +47,7 @@ class Function : public std::enable_shared_from_this { protected: std::vector> saved_tensors_; + std::vector needs_input_grad_; private: std::vector, int>> next_functions_; diff --git a/infini_train/include/autograd/linear.h b/infini_train/include/autograd/linear.h index d043d2f4..21d107b9 100644 --- a/infini_train/include/autograd/linear.h +++ b/infini_train/include/autograd/linear.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -10,6 +11,13 @@ class Tensor; } namespace infini_train::autograd { + +struct LinearGradFlags { + bool input = false; + bool weight = false; + bool bias = false; +}; + class Linear : public Function { public: static constexpr char kType[] = "LinearFunction"; @@ -22,7 +30,10 @@ class Linear : public Function { std::vector> Backward(const std::vector> &grad_outputs) override; private: + bool transpose_ = false; + bool bias_ = false; + int64_t in_features_ = 0; int64_t out_features_ = 0; - bool bias_ = true; + std::vector input_dims_; }; } // namespace infini_train::autograd diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index a40d0987..39d6dd46 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -138,6 +138,10 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr View(const std::vector &dims); std::shared_ptr Contiguous(); + // FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor + // class before this can be implemented correctly. The guard in elementwise.cu ensures + // non-contiguous tensors fall back to the broadcast path until this is resolved. 
+ bool IsContiguous() const; std::shared_ptr Flatten(int64_t start = 0, int64_t end = -1); std::shared_ptr Squeeze(int64_t dim); std::shared_ptr Unsqueeze(int64_t dim); diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index 37846bbf..d9b70bc1 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -26,6 +26,13 @@ AccumulateGrad::Backward(const std::vector> &grad_output core::DeviceGuard guard(device); if (grad_output) { + if (grad_output->Dtype() != tensor_->Dtype()) { + LOG(WARNING) << "AccumulateGrad: grad dtype (" << kDataTypeToDesc.at(grad_output->Dtype()) + << ") does not match parameter dtype (" << kDataTypeToDesc.at(tensor_->Dtype()) + << "). This indicates a dtype mismatch in the autograd graph (e.g. autocast " + "running before autograd). The grad is not cast and will be used as-is."; + } + if (grad) { if (tensor_->ConsumeGradOverwriteFlag()) { // If the tensor is marked to overrite its current grad on next grad update diff --git a/infini_train/src/autograd/elementwise.cc b/infini_train/src/autograd/elementwise.cc index 7291e284..655cd309 100644 --- a/infini_train/src/autograd/elementwise.cc +++ b/infini_train/src/autograd/elementwise.cc @@ -390,6 +390,11 @@ std::vector> Add::Backward(const std::vectorGetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "AddBackward"}, grad_output, a_dims_, b_dims_); diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index 42a95729..3d741908 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -36,6 +36,16 @@ std::vector> Function::Apply(const std::vectorrequires_grad(); + } + } + std::vector> output_tensors; { autograd::NoGradGuard no_grad; @@ -129,6 +139,7 @@ void Function::BackwardPartial(const std::shared_ptr &grad_output, int g saved_tensors_.clear(); grad_outputs_.clear(); + 
needs_input_grad_.clear(); grad_outputs_reached_ = 0; dependencies_reached_ = 0; diff --git a/infini_train/src/autograd/linear.cc b/infini_train/src/autograd/linear.cc index be397c32..c9ed1dbb 100644 --- a/infini_train/src/autograd/linear.cc +++ b/infini_train/src/autograd/linear.cc @@ -17,12 +17,35 @@ std::vector> Linear::Forward(const std::vector> &input_tensors, - const std::vector> &) { + const std::vector> &output_tensors) { const auto &input = input_tensors[0]; const auto &weight = input_tensors[1]; - saved_tensors_ = {input, weight}; + // Cast saved tensors to forward compute dtype (output dtype) so backward + // computes in the same precision as forward, matching PyTorch's behavior. + + // FIXME: An extra cast (input/weight -> compute_dtype) is performed here because + // autocast runs before autograd. The correct approach is to adjust the ordering or + // integration of autocast and autograd so that autograd receives already-cast tensors, + // avoiding the redundant cast. + + // FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be + // determined by autocast, not derived from output_tensors[0]->Dtype(). + auto compute_dtype = output_tensors[0]->Dtype(); + bool need_input = needs_input_grad_.size() > 0 && needs_input_grad_[0]; + bool need_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + + auto cast = [&](const std::shared_ptr &t) { + return t->Dtype() == compute_dtype ? t : std::make_shared(t->To(compute_dtype)); + }; + + // grad_input needs weight, grad_weight needs input + saved_tensors_ = {need_weight ? cast(input) : nullptr, need_input ? 
cast(weight) : nullptr}; + + transpose_ = true; bias_ = input_tensors.size() == 3; + in_features_ = weight->Dims()[1]; out_features_ = weight->Dims()[0]; + input_dims_ = input->Dims(); } std::vector> Linear::Backward(const std::vector> &grad_outputs) { @@ -32,13 +55,22 @@ std::vector> Linear::Backward(const std::vectorGetDevice().type(); + CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward"; + LinearGradFlags grad_flags = {.input = needs_input_grad_[0], + .weight = needs_input_grad_.size() > 1 && needs_input_grad_[1], + .bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]}; + + auto device = grad_output->GetDevice().type(); + // TODO: skip autograd graph construction entirely when no input requires grad auto [grad_input, grad_weight, grad_bias] = Dispatcher::Instance() .Call, std::shared_ptr, std::shared_ptr>>( - {device, "LinearBackward"}, input, weight, true, out_features_, grad_output, bias_); - return bias_ ? std::vector>{grad_input, grad_weight, grad_bias} - : std::vector>{grad_input, grad_weight}; - ; + {device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_, + grad_output, bias_, grad_flags); + if (bias_) { + return {grad_input, grad_weight, grad_bias}; + } else { + return {grad_input, grad_weight}; + } } } // namespace infini_train::autograd diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 335396d6..49f593bf 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -20,7 +20,21 @@ void Matmul::SetupContext(const std::vector> &input_tens const auto &input1 = input_tensors[0]; const auto &input2 = input_tensors[1]; const auto &output = output_tensors[0]; - saved_tensors_ = {input1, input2}; + // Cast saved tensors to forward compute dtype (output dtype) so backward + // computes in the same precision as forward, matching PyTorch's behavior. 
+ + // FIXME: An extra cast (input1/input2 -> compute_dtype) is performed here because + // autocast runs before autograd. The correct approach is to adjust the ordering or + // integration of autocast and autograd so that autograd receives already-cast tensors, + // avoiding the redundant cast. + + // FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be + // determined by autocast, not derived from output->Dtype(). + auto compute_dtype = output->Dtype(); + saved_tensors_ = { + input1->Dtype() == compute_dtype ? input1 : std::make_shared(input1->To(compute_dtype)), + input2->Dtype() == compute_dtype ? input2 : std::make_shared(input2->To(compute_dtype)), + }; out_features_ = output->Dims()[0]; } diff --git a/infini_train/src/kernels/cpu/linear.cc b/infini_train/src/kernels/cpu/linear.cc index e7de3fa9..2b209417 100644 --- a/infini_train/src/kernels/cpu/linear.cc +++ b/infini_train/src/kernels/cpu/linear.cc @@ -1,11 +1,11 @@ #include -#include #include #include #include #include "glog/logging.h" +#include "infini_train/include/autograd/linear.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -70,6 +70,7 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptr LinearForward(const std::shared_ptr &input, cons // TODO(dcj): support linear without bias later std::tuple, std::shared_ptr, std::shared_ptr> LinearBackward(const std::shared_ptr &input, const std::shared_ptr &weight, bool transpose, - int64_t out_features, const std::shared_ptr &grad_output, const bool bias) { + int64_t in_features, int64_t out_features, const std::vector &input_dims, + const std::shared_ptr &grad_output, bool bias, + infini_train::autograd::LinearGradFlags grad_flags) { /* transpose: grad_input = grad_output * weight grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features] @@ -160,32 +163,41 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptrDims(); 
CHECK_GE(input_dims.size(), 2); - const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies{}); - const int64_t in_features = *input_dims.rbegin(); - const auto &weight_dims = weight->Dims(); - CHECK_EQ(weight_dims.size(), 2); - CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]); - CHECK_EQ(out_features, weight_dims[transpose ? 0 : 1]); + std::vector weight_dims + = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; - auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); - auto grad_weight = std::make_shared(weight_dims, DataType::kFLOAT32); + std::shared_ptr grad_input = nullptr; + std::shared_ptr grad_weight = nullptr; std::shared_ptr grad_bias = nullptr; - if (bias) { - grad_bias = std::make_shared(std::vector{out_features}, DataType::kFLOAT32); + + if (compute_grad_input) { + CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; + grad_input = std::make_shared(input_dims, DataType::kFLOAT32); + if (transpose) { + grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix(); + } else { + grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose(); + } } - if (transpose) { - grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix(); - grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix(); - } else { - grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose(); - grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix(); + if (compute_grad_weight) { + CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; + grad_weight = std::make_shared(weight_dims, DataType::kFLOAT32); + if (transpose) { + grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix(); + } else { + 
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix(); + } } - if (bias) { + + if (compute_grad_bias && bias) { + grad_bias = std::make_shared(std::vector{out_features}, DataType::kFLOAT32); grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum(); } diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index 6b356156..01e1048f 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -15,6 +15,57 @@ namespace { using namespace infini_train::common::cuda; constexpr int kWarpSize = 32; +// Aligned vector type for vectorized loads/stores (128-bit). +template struct __align__(sizeof(T) * N) aligned_vector { T val[N]; }; + +// Elements per vectorized load/store: 128-bit / sizeof(T). +// float → 4, bf16/half → 8, double → 2. +template constexpr int kVecSize = 16 / sizeof(T); + +// Maximum number of dimensions supported by the broadcast metadata. +// Real-world tensors in this codebase top out at 4-5 dims, so 8 leaves comfortable headroom +// while keeping the struct under the 4 KB CUDA kernel parameter limit. +constexpr int kMaxBroadcastDims = 8; + +// POD metadata for broadcast kernels. Passed by value into __global__ kernels so the data +// lives in CUDA kernel parameter memory (constant cache) instead of being uploaded via a +// per-call cudaMallocAsync + cudaMemcpyAsync into global memory. +struct BroadcastMeta { + int ndim; + int64_t a_strides[kMaxBroadcastDims]; + int64_t b_strides[kMaxBroadcastDims]; + int64_t out_strides[kMaxBroadcastDims]; + int64_t a_shape[kMaxBroadcastDims]; + int64_t b_shape[kMaxBroadcastDims]; +}; + +// Build a BroadcastMeta on the host from input/output dim vectors. Right-aligns a_dims/b_dims +// to out_dims's rank (the broadcasting convention) and computes contiguous strides for each. 
+inline BroadcastMeta MakeBroadcastMeta(const std::vector &a_dims, const std::vector &b_dims, + const std::vector &out_dims) { + BroadcastMeta m{}; + const int ndim = static_cast(out_dims.size()); + CHECK_LE(ndim, kMaxBroadcastDims) << "Broadcast ndim exceeds kMaxBroadcastDims (" << kMaxBroadcastDims << ")"; + m.ndim = ndim; + + std::vector a_shape(ndim, 1), b_shape(ndim, 1); + std::copy_backward(a_dims.begin(), a_dims.end(), a_shape.end()); + std::copy_backward(b_dims.begin(), b_dims.end(), b_shape.end()); + + auto a_str = ComputeStrides(a_shape); + auto b_str = ComputeStrides(b_shape); + auto out_str = ComputeStrides(out_dims); + + for (int i = 0; i < ndim; ++i) { + m.a_strides[i] = a_str[i]; + m.b_strides[i] = b_str[i]; + m.out_strides[i] = out_str[i]; + m.a_shape[i] = a_shape[i]; + m.b_shape[i] = b_shape[i]; + } + return m; +} + template __global__ void UnaryForwardKernel(T *output, Func fn, size_t num_elements, size_t offset, const T *input) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; @@ -36,21 +87,124 @@ __device__ inline int64_t CalcOffset(int64_t idx, int ndim, const int64_t *strid return offset; } +inline bool ShapesEqual(const std::vector &a, const std::vector &b) { + if (a.size() != b.size()) { + return false; + } + for (size_t i = 0; i < a.size(); ++i) { + if (a[i] != b[i]) { + return false; + } + } + return true; +} + template -__global__ void BinaryForwardKernel(T *output, Func fn, int ndim, const int64_t *a_strides, const int64_t *a_shape, - const int64_t *b_strides, const int64_t *b_shape, const int64_t *out_strides, - const T *a, const T *b, size_t num_elements) { +__global__ void BinaryForwardKernel(T *output, Func fn, BroadcastMeta meta, const T *a, const T *b, + size_t num_elements) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_elements) { return; } - int64_t a_offset = CalcOffset(idx, ndim, a_strides, a_shape, out_strides); - int64_t b_offset = CalcOffset(idx, ndim, b_strides, b_shape, out_strides); + 
int64_t a_offset = CalcOffset(idx, meta.ndim, meta.a_strides, meta.a_shape, meta.out_strides); + int64_t b_offset = CalcOffset(idx, meta.ndim, meta.b_strides, meta.b_shape, meta.out_strides); output[idx] = fn(a[a_offset], b[b_offset]); } +// Fast path: no broadcast, contiguous tensors — skip CalcOffset entirely +template +__global__ void BinaryForwardKernelNoBroadcast(T *__restrict__ output, Func fn, const T *__restrict__ a, + const T *__restrict__ b, size_t num_elements) { + const size_t grid_stride = static_cast(gridDim.x) * blockDim.x; + for (size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < num_elements; + idx += grid_stride) { + output[idx] = fn(a[idx], b[idx]); + } +} + +// Fast path backward: no broadcast, contiguous — skip CalcOffset entirely +template +__global__ void BinaryBackwardKernelNoBroadcastFast(T *__restrict__ outA, T *__restrict__ outB, FuncA fn_a, FuncB fn_b, + size_t numel, const T *__restrict__ grad_out, + const T *__restrict__ inA, const T *__restrict__ inB) { + const size_t grid_stride = static_cast(gridDim.x) * blockDim.x; + for (size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < numel; idx += grid_stride) { + const T a = inA ? inA[idx] : T(0); + const T b = inB ? inB[idx] : T(0); + outA[idx] = Mul(grad_out[idx], fn_a(a, b)); + outB[idx] = Mul(grad_out[idx], fn_b(a, b)); + } +} + +// Vectorized fast path backward: no broadcast, contiguous. +// Each thread processes VecSize elements using 128-bit loads/stores. 
+template +__global__ void BinaryBackwardKernelNoBroadcastVectorized(T *__restrict__ outA, T *__restrict__ outB, FuncA fn_a, + FuncB fn_b, size_t numel, const T *__restrict__ grad_out, + const T *__restrict__ inA, const T *__restrict__ inB) { + using VecT = aligned_vector; + const size_t num_vecs = numel / VecSize; + const size_t grid_stride = static_cast(gridDim.x) * blockDim.x; + + for (size_t vid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; vid < num_vecs; vid += grid_stride) { + const size_t base = vid * VecSize; + + // 128-bit vectorized loads + VecT g_vec = *reinterpret_cast(&grad_out[base]); + VecT a_vec, b_vec; + if (inA) { + a_vec = *reinterpret_cast(&inA[base]); + } else { +#pragma unroll + for (int i = 0; i < VecSize; ++i) { a_vec.val[i] = T(0); } + } + if (inB) { + b_vec = *reinterpret_cast(&inB[base]); + } else { +#pragma unroll + for (int i = 0; i < VecSize; ++i) { b_vec.val[i] = T(0); } + } + + // Element-wise computation + VecT outA_vec, outB_vec; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + outA_vec.val[i] = Mul(g_vec.val[i], fn_a(a_vec.val[i], b_vec.val[i])); + outB_vec.val[i] = Mul(g_vec.val[i], fn_b(a_vec.val[i], b_vec.val[i])); + } + + // 128-bit vectorized stores + *reinterpret_cast(&outA[base]) = outA_vec; + *reinterpret_cast(&outB[base]) = outB_vec; + } + + // Handle tail elements (numel % VecSize != 0) + const size_t tail_start = num_vecs * VecSize; + for (size_t idx = tail_start + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < numel; + idx += grid_stride) { + const T a = inA ? inA[idx] : T(0); + const T b = inB ? 
inB[idx] : T(0); + outA[idx] = Mul(grad_out[idx], fn_a(a, b)); + outB[idx] = Mul(grad_out[idx], fn_b(a, b)); + } +} + +// Helper to choose optimal block size based on tensor size +inline size_t ChooseBlockSize(size_t num_elements) { + if (num_elements < 1024) { + return 64; + } + if (num_elements < 65536) { + return 128; + } + if (num_elements < 1048576) { + return 256; + } + return 512; +} + // launch the given kernel function with the given output and inputs template void LaunchKernel(Kernel &&kernel, const std::shared_ptr &output, const Inputs &...inputs) { @@ -59,7 +213,9 @@ void LaunchKernel(Kernel &&kernel, const std::shared_ptr &output, const auto input_ptrs = extract_ptrs(inputs...); const size_t num_elements = output->NumElements(); - dim3 block_dims(std::min(BLOCK_SIZE, static_cast(1024))); + // Use dynamic block size based on tensor size for better occupancy + size_t block_size = std::min(ChooseBlockSize(num_elements), static_cast(1024)); + dim3 block_dims(block_size); dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x)); const size_t step = grid_dims.x * block_dims.x; @@ -95,46 +251,32 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input const auto &a_dims = input_a->Dims(); const auto &b_dims = input_b->Dims(); const auto &out_dims = output->Dims(); - int ndim = out_dims.size(); - std::vector a_shape(ndim, 1), b_shape(ndim, 1), out_shape(ndim, 1); - std::copy_backward(a_dims.begin(), a_dims.end(), a_shape.end()); - std::copy_backward(b_dims.begin(), b_dims.end(), b_shape.end()); - std::copy_backward(out_dims.begin(), out_dims.end(), out_shape.end()); + // Fast path: no broadcast, contiguous — skip cudaMalloc/Memcpy/CalcOffset. + // The IsContiguous() guards ensure non-contiguous tensors fall back to the broadcast + // path, keeping the fast path correct when non-contiguous support is added later. 
+ if (ShapesEqual(a_dims, out_dims) && ShapesEqual(b_dims, out_dims) && input_a->IsContiguous() + && input_b->IsContiguous()) { + const size_t num_elements = output->NumElements(); + const T *a_ptr = static_cast(input_a->DataPtr()); + const T *b_ptr = static_cast(input_b->DataPtr()); + dim3 block_dims(std::min(BLOCK_SIZE, static_cast(1024))); + dim3 grid_dims(std::min(CEIL_DIV(num_elements, block_dims.x), static_cast(65535))); + BinaryForwardKernelNoBroadcast<<>>(output_ptr, func, a_ptr, b_ptr, + num_elements); + } else { + // Broadcast path: pass strides/shapes by value via kernel parameter memory. + // This avoids the per-call cudaMallocAsync/cudaMemcpyAsync/cudaFreeAsync that previously + // dominated the host-side jitter floor (especially under LoRA training). + BroadcastMeta meta = MakeBroadcastMeta(a_dims, b_dims, out_dims); - auto a_stride_host = ComputeStrides(a_shape); - auto b_stride_host = ComputeStrides(b_shape); - auto out_stride_host = ComputeStrides(out_shape); - - int64_t *device_buffer; - cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), cuda_stream); - - int64_t *device_a_strides, *device_b_strides, *device_out_strides, *device_a_shape, *device_b_shape; - device_a_strides = device_buffer + ndim * 0; - device_b_strides = device_buffer + ndim * 1; - device_out_strides = device_buffer + ndim * 2; - device_a_shape = device_buffer + ndim * 3; - device_b_shape = device_buffer + ndim * 4; - - std::vector host_buffer; - host_buffer.insert(host_buffer.end(), a_stride_host.begin(), a_stride_host.end()); - host_buffer.insert(host_buffer.end(), b_stride_host.begin(), b_stride_host.end()); - host_buffer.insert(host_buffer.end(), out_stride_host.begin(), out_stride_host.end()); - host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end()); - host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end()); - - cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, - cuda_stream); - - 
LaunchKernel( - [&](dim3 grid, dim3 block, size_t offset, const T *a_ptr, const T *b_ptr) { - BinaryForwardKernel<<>>( - output_ptr, func, ndim, device_a_strides, device_a_shape, device_b_strides, device_b_shape, - device_out_strides, a_ptr, b_ptr, output->NumElements()); - }, - output, inputs...); - - cudaFreeAsync(device_buffer, cuda_stream); + LaunchKernel( + [&](dim3 grid, dim3 block, size_t /*offset*/, const T *a_ptr, const T *b_ptr) { + BinaryForwardKernel<<>>(output_ptr, func, meta, a_ptr, b_ptr, + output->NumElements()); + }, + output, inputs...); + } } else { static_assert(sizeof...(inputs) == 1 || sizeof...(inputs) == 2, "LaunchForward currently only supports unary and binary operations."); @@ -154,18 +296,6 @@ __global__ void UnaryBackwardKernel(T *output, Func fn, size_t num_elements, siz enum class BF16Path { NoBroadcast, TwoPassHist, BlockReduce }; -inline bool ShapesEqual(const std::vector &a, const std::vector &b) { - if (a.size() != b.size()) { - return false; - } - for (size_t i = 0; i < a.size(); ++i) { - if (a[i] != b[i]) { - return false; - } - } - return true; -} - // Lightweight and stable selector for bf16/half execution paths. inline BF16Path DecideBF16Path(const std::vector &b_shape, const std::vector &out_shape, size_t b_num_elements) { @@ -183,16 +313,13 @@ inline BF16Path DecideBF16Path(const std::vector &b_shape, const std::v // Each B element is used exactly once, so gradients can be written directly without reduction. 
template -__global__ void -BinaryBackwardKernelNoBroadcast(T *__restrict__ outA, T *__restrict__ outB, FuncA fn_a, FuncB fn_b, int ndim, - size_t numel, const int64_t *__restrict__ a_strides, - const int64_t *__restrict__ a_shape, const int64_t *__restrict__ b_strides, - const int64_t *__restrict__ b_shape, const int64_t *__restrict__ out_strides, - const T *__restrict__ grad_out, const T *__restrict__ inA, const T *__restrict__ inB) { +__global__ void BinaryBackwardKernelNoBroadcast(T *__restrict__ outA, T *__restrict__ outB, FuncA fn_a, FuncB fn_b, + BroadcastMeta meta, size_t numel, const T *__restrict__ grad_out, + const T *__restrict__ inA, const T *__restrict__ inB) { const size_t grid_stride = static_cast(gridDim.x) * blockDim.x; for (size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < numel; idx += grid_stride) { - const int64_t a_off = CalcOffset(idx, ndim, a_strides, a_shape, out_strides); - const int64_t b_off = CalcOffset(idx, ndim, b_strides, b_shape, out_strides); + const int64_t a_off = CalcOffset(idx, meta.ndim, meta.a_strides, meta.a_shape, meta.out_strides); + const int64_t b_off = CalcOffset(idx, meta.ndim, meta.b_strides, meta.b_shape, meta.out_strides); const T a = inA ? inA[a_off] : T(0); const T b = inB ? inB[b_off] : T(0); @@ -207,12 +334,9 @@ BinaryBackwardKernelNoBroadcast(T *__restrict__ outA, T *__restrict__ outB, Func // First pass of histogram two-pass strategy: per-block accumulation in shared memory. 
template -__global__ void -BinaryBackwardBhistPass1Kernel(T *__restrict__ outA, float *__restrict__ work, FuncA fn_a, FuncB fn_b, int ndim, - size_t numel, int K, const int64_t *__restrict__ a_strides, - const int64_t *__restrict__ a_shape, const int64_t *__restrict__ b_strides, - const int64_t *__restrict__ b_shape, const int64_t *__restrict__ out_strides, - const T *__restrict__ grad_out, const T *__restrict__ inA, const T *__restrict__ inB) { +__global__ void BinaryBackwardBhistPass1Kernel(T *__restrict__ outA, float *__restrict__ work, FuncA fn_a, FuncB fn_b, + BroadcastMeta meta, size_t numel, int K, const T *__restrict__ grad_out, + const T *__restrict__ inA, const T *__restrict__ inB) { extern __shared__ float s_hist[]; // dynamic shared memory: K bins plus padding for every 32 buckets const int pad = K >> 5; // insert one padding slot for every 32 buckets const int hist_len = K + pad; @@ -224,12 +348,12 @@ BinaryBackwardBhistPass1Kernel(T *__restrict__ outA, float *__restrict__ work, F const size_t total_threads = (size_t)gridDim.x * blockDim.x; for (size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < numel; idx += total_threads) { // Linearized offset for B under general broadcasting. - const int64_t b_off = CalcOffset(idx, ndim, b_strides, b_shape, out_strides); + const int64_t b_off = CalcOffset(idx, meta.ndim, meta.b_strides, meta.b_shape, meta.out_strides); const int bin = static_cast(b_off); // assume K fits in a 32-bit int const int pbin = bin + (bin >> 5); // apply padding mapping // Compute the offset for A under broadcasting. - const int64_t a_off = CalcOffset(idx, ndim, a_strides, a_shape, out_strides); + const int64_t a_off = CalcOffset(idx, meta.ndim, meta.a_strides, meta.a_shape, meta.out_strides); const T a = inA ? inA[a_off] : T(0); const T b = inB ? 
inB[bin] : T(0); // B is indexed via the flattened bin @@ -293,10 +417,8 @@ __global__ void BinaryBackwardBhistPass2Reduce1D(const float *__restrict__ work, // Helper that materializes the two-pass histogram path for bf16/half B gradients. template -void BinaryBackwardBhistLaunch(FuncA fn_a, FuncB fn_b, T *outA, T *outB, const T *grad_out, int ndim, size_t numel, - int K, const int64_t *a_strides, const int64_t *a_shape, const int64_t *b_strides, - const int64_t *b_shape, const int64_t *out_strides, const T *inA, const T *inB, - cudaStream_t stream) { +void BinaryBackwardBhistLaunch(FuncA fn_a, FuncB fn_b, T *outA, T *outB, const T *grad_out, const BroadcastMeta &meta, + size_t numel, int K, const T *inA, const T *inB, cudaStream_t stream) { const int kBlockSize = 256; int grid = static_cast((numel + kBlockSize - 1) / kBlockSize); if (grid < 1) { @@ -310,8 +432,7 @@ void BinaryBackwardBhistLaunch(FuncA fn_a, FuncB fn_b, T *outA, T *outB, const T // Pass 1: per-block histogram accumulation. const size_t smem_bytes = static_cast(K + (K >> 5)) * sizeof(float); BinaryBackwardBhistPass1Kernel - <<>>(outA, work, fn_a, fn_b, ndim, numel, K, a_strides, a_shape, - b_strides, b_shape, out_strides, grad_out, inA, inB); + <<>>(outA, work, fn_a, fn_b, meta, numel, K, grad_out, inA, inB); CUDA_CHECK(cudaGetLastError()); // Pass 2: choose between 1D and 2D reductions depending on workload shape. 
@@ -360,10 +481,8 @@ void BinaryBackwardBhistLaunch(FuncA fn_a, FuncB fn_b, T *outA, T *outB, const T // Backward kernel for binary operators // TODO(lzm): determining and passing b_is_broadcasted from the caller; optimize further template -__global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, int ndim, size_t num_elements, - const int64_t *a_strides, const int64_t *a_shape, const int64_t *b_strides, - const int64_t *b_shape, const int64_t *out_strides, const T *grad_output, - const T *input_a, const T *input_b) { +__global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, BroadcastMeta meta, + size_t num_elements, const T *grad_output, const T *input_a, const T *input_b) { extern __shared__ char shared_memory[]; const int tid = threadIdx.x; const int warp_id = tid / 32; @@ -380,8 +499,8 @@ __global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB float grad_val = 0.0f; if (in_bounds) { - a_offset = CalcOffset(idx, ndim, a_strides, a_shape, out_strides); - b_offset = CalcOffset(idx, ndim, b_strides, b_shape, out_strides); + a_offset = CalcOffset(idx, meta.ndim, meta.a_strides, meta.a_shape, meta.out_strides); + b_offset = CalcOffset(idx, meta.ndim, meta.b_strides, meta.b_shape, meta.out_strides); a_val = input_a ? input_a[a_offset] : T(0); b_val = input_b ? 
input_b[b_offset] : T(0); output_a[a_offset] = Mul(grad_output[idx], fn_a(a_val, b_val)); @@ -423,10 +542,9 @@ __global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB // NOTE(dcj): Specialized BinaryBackwardKernel for low-precision types (__half / bfloat16) template -__global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, int ndim, size_t num_elements, - size_t b_num_elements, const int64_t *a_strides, const int64_t *a_shape, - const int64_t *b_strides, const int64_t *b_shape, const int64_t *out_strides, - const T *grad_output, const T *input_a, const T *input_b, bool fast_atomics) { +__global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, BroadcastMeta meta, + size_t num_elements, size_t b_num_elements, const T *grad_output, const T *input_a, + const T *input_b, bool fast_atomics) { const int tid = threadIdx.x; const int block_threads = blockDim.x; @@ -447,8 +565,8 @@ __global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB T a_val = T(0), b_val = T(0); if (in_bounds) { - a_offset = CalcOffset(global_idx, ndim, a_strides, a_shape, out_strides); - b_offset = CalcOffset(global_idx, ndim, b_strides, b_shape, out_strides); + a_offset = CalcOffset(global_idx, meta.ndim, meta.a_strides, meta.a_shape, meta.out_strides); + b_offset = CalcOffset(global_idx, meta.ndim, meta.b_strides, meta.b_shape, meta.out_strides); a_val = input_a ? input_a[a_offset] : T(0); b_val = input_b ? 
input_b[b_offset] : T(0); @@ -526,51 +644,64 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out const T *grad_output_ptr = static_cast(grad_output->DataPtr()); const auto &out_dims = grad_output->Dims(); - int ndim = out_dims.size(); - - std::vector a_shape(ndim, 1), b_shape(ndim, 1), out_shape(ndim, 1); - std::copy_backward(a_dims.begin(), a_dims.end(), a_shape.end()); - std::copy_backward(b_dims.begin(), b_dims.end(), b_shape.end()); - std::copy_backward(out_dims.begin(), out_dims.end(), out_shape.end()); - - auto a_stride_host = ComputeStrides(a_shape); - auto b_stride_host = ComputeStrides(b_shape); - auto out_stride_host = ComputeStrides(out_shape); - - int64_t *device_buffer; - cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), stream); - - int64_t *device_a_strides, *device_b_strides, *device_out_strides, *device_a_shape, *device_b_shape; - device_a_strides = device_buffer + ndim * 0; - device_b_strides = device_buffer + ndim * 1; - device_out_strides = device_buffer + ndim * 2; - device_a_shape = device_buffer + ndim * 3; - device_b_shape = device_buffer + ndim * 4; - - std::vector host_buffer; - host_buffer.insert(host_buffer.end(), a_stride_host.begin(), a_stride_host.end()); - host_buffer.insert(host_buffer.end(), b_stride_host.begin(), b_stride_host.end()); - host_buffer.insert(host_buffer.end(), out_stride_host.begin(), out_stride_host.end()); - host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end()); - host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end()); + const size_t num_elements = grad_output->NumElements(); - cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream); + // Fast path: no broadcast, contiguous — skip cudaMalloc/Memcpy/CalcOffset. + // The IsContiguous() guard ensures non-contiguous grad_output falls back to the broadcast + // path, keeping the fast path correct when non-contiguous support is added later. 
+ if (ShapesEqual(a_dims, b_dims) && ShapesEqual(a_dims, out_dims) && grad_output->IsContiguous()) { + auto extract_ptrs = [](const auto &...ts) { + return std::make_tuple(static_cast(ts ? ts->DataPtr() : nullptr)...); + }; + auto [input_a_ptr, input_b_ptr] = extract_ptrs(inputs...); + + constexpr int VecSize = kVecSize; + // Use vectorized kernel if all pointers are 16-byte aligned and numel is large enough + const bool can_vectorize + = (num_elements >= static_cast(VecSize)) + && (reinterpret_cast(output_a_ptr) % (sizeof(T) * VecSize) == 0) + && (reinterpret_cast(output_b_ptr) % (sizeof(T) * VecSize) == 0) + && (reinterpret_cast(grad_output_ptr) % (sizeof(T) * VecSize) == 0) + && (!input_a_ptr || reinterpret_cast(input_a_ptr) % (sizeof(T) * VecSize) == 0) + && (!input_b_ptr || reinterpret_cast(input_b_ptr) % (sizeof(T) * VecSize) == 0); + + if (can_vectorize) { + const size_t num_vecs = num_elements / VecSize; + dim3 block_dims(std::min(static_cast(256), std::min(num_vecs, static_cast(1024)))); + dim3 grid_dims(std::min(CEIL_DIV(num_vecs, block_dims.x), static_cast(65535))); + BinaryBackwardKernelNoBroadcastVectorized<<>>( + output_a_ptr, output_b_ptr, fun_a, fun_b, num_elements, grad_output_ptr, input_a_ptr, input_b_ptr); + } else { + dim3 block_dims(std::min(BLOCK_SIZE, static_cast(1024))); + dim3 grid_dims(std::min(CEIL_DIV(num_elements, block_dims.x), static_cast(65535))); + BinaryBackwardKernelNoBroadcastFast<<>>( + output_a_ptr, output_b_ptr, fun_a, fun_b, num_elements, grad_output_ptr, input_a_ptr, input_b_ptr); + } + return; + } - const size_t num_elements = grad_output->NumElements(); + // Broadcast path: pass strides/shapes by value via kernel parameter memory. + // This avoids the per-call cudaMallocAsync/cudaMemcpyAsync/cudaFreeAsync that previously + // dominated the host-side jitter floor (especially under LoRA training). 
+ BroadcastMeta meta = MakeBroadcastMeta(a_dims, b_dims, out_dims); if constexpr (std::is_same_v) { LaunchKernel( - [=](dim3 grid, dim3 block, size_t offset, auto... ptrs) { + [=](dim3 grid, dim3 block, size_t /*offset*/, auto... ptrs) { const int num_warps = BLOCK_SIZE / kWarpSize; const size_t smem_size = num_warps * sizeof(cub::WarpReduce::TempStorage); - BinaryBackwardKernel<<>>( - output_a_ptr, output_b_ptr, fun_a, fun_b, ndim, num_elements, device_a_strides, device_a_shape, - device_b_strides, device_b_shape, device_out_strides, grad_output_ptr, ptrs...); + BinaryBackwardKernel<<>>(output_a_ptr, output_b_ptr, fun_a, fun_b, meta, + num_elements, grad_output_ptr, ptrs...); }, output_a, inputs...); } else if constexpr (std::is_same_v || std::is_same_v) { // Dynamically choose the most efficient bf16/half strategy based on broadcast pattern. - // Compute b_num_elements, which also serves as K for the histogram path. + // Reconstruct right-aligned b_shape (stack-only, no device allocations) for + // DecideBF16Path which still operates on std::vector. + const int ndim = meta.ndim; + std::vector b_shape(meta.b_shape, meta.b_shape + ndim); + const std::vector &out_shape = out_dims; + size_t b_num_elements = 1; for (auto v : b_shape) { b_num_elements *= static_cast(v); } const int K_linear = static_cast(b_num_elements); @@ -583,8 +714,7 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out LaunchKernel( [=](dim3 grid, dim3 block, size_t /*offset*/, auto... 
ptrs) { BinaryBackwardKernelNoBroadcast<<>>( - output_a_ptr, output_b_ptr, fun_a, fun_b, ndim, num_elements, device_a_strides, device_a_shape, - device_b_strides, device_b_shape, device_out_strides, grad_output_ptr, ptrs...); + output_a_ptr, output_b_ptr, fun_a, fun_b, meta, num_elements, grad_output_ptr, ptrs...); }, output_a, inputs...); return; @@ -594,10 +724,9 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out // Small K with variation in the innermost dimension: use two-pass histogram strategy. LaunchKernel( [=](dim3 /*grid*/, dim3 /*block*/, size_t /*offset*/, const T *input_a_ptr, const T *input_b_ptr) { - BinaryBackwardBhistLaunch( - fun_a, fun_b, output_a_ptr, output_b_ptr, grad_output_ptr, ndim, num_elements, K_linear, - device_a_strides, device_a_shape, device_b_strides, device_b_shape, device_out_strides, - input_a_ptr, input_b_ptr, stream); + BinaryBackwardBhistLaunch(fun_a, fun_b, output_a_ptr, output_b_ptr, + grad_output_ptr, meta, num_elements, K_linear, + input_a_ptr, input_b_ptr, stream); }, output_a, inputs...); @@ -606,17 +735,15 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out // Otherwise fall back to the block-reduction kernel with SoA layout and fast atomics. LaunchKernel( - [=](dim3 grid, dim3 block, size_t offset, auto... ptrs) { + [=](dim3 grid, dim3 block, size_t /*offset*/, auto... 
ptrs) { const int padded_block = BLOCK_SIZE + BLOCK_SIZE / kWarpSize; const size_t smem_size = static_cast(padded_block) * (sizeof(int64_t) + sizeof(float)); BinaryBackwardKernel<<>>( - output_a_ptr, output_b_ptr, fun_a, fun_b, ndim, num_elements, output_b->NumElements(), - device_a_strides, device_a_shape, device_b_strides, device_b_shape, device_out_strides, + output_a_ptr, output_b_ptr, fun_a, fun_b, meta, num_elements, output_b->NumElements(), grad_output_ptr, ptrs..., /*fast_atomics=*/true); }, output_a, inputs...); } - cudaFreeAsync(device_buffer, stream); } template std::shared_ptr UnaryForward(const std::shared_ptr &input, Func unary_fn) { @@ -649,20 +776,11 @@ std::shared_ptr UnaryBackward(const std::shared_ptr &grad_output auto output = std::make_shared(grad_output->Dims(), promoted_type, grad_output->GetDevice()); switch (promoted_type) { - DISPATCH_CASE(WRAP({ - output->Fill(0.0f); - LaunchBackward<256, float>(unary_fn, output, grad_output_promoted, a_promoted); - }), + DISPATCH_CASE(WRAP({ LaunchBackward<256, float>(unary_fn, output, grad_output_promoted, a_promoted); }), DataType::kFLOAT32) - DISPATCH_CASE(WRAP({ - output->Fill(0); - LaunchBackward<256, nv_bfloat16>(unary_fn, output, grad_output_promoted, a_promoted); - }), + DISPATCH_CASE(WRAP({ LaunchBackward<256, nv_bfloat16>(unary_fn, output, grad_output_promoted, a_promoted); }), DataType::kBFLOAT16) - DISPATCH_CASE(WRAP({ - output->Fill(0); - LaunchBackward<256, int64_t>(unary_fn, output, grad_output_promoted, a_promoted); - }), + DISPATCH_CASE(WRAP({ LaunchBackward<256, int64_t>(unary_fn, output, grad_output_promoted, a_promoted); }), DataType::kINT64) default: LOG_LOC(FATAL, "CUDA unary backward: 'Unsupported data type'"); @@ -718,11 +836,10 @@ BinaryBackward(const std::shared_ptr &grad_output, const std::shared_ptr auto a_dtype = a_promoted ? a_promoted->Dtype() : dtype; auto b_dtype = b_promoted ? 
b_promoted->Dtype() : dtype; - DataType promoted_type - = DispatchFunc, DataTypeList, DataTypeList>( - {a_dtype, b_dtype, dtype}, - [=]() { return DataTypeMap_v>; }, - "CUDA BinaryBackward"); + // Compute dtype determined by saved tensors (forward compute dtype), not grad_output + DataType promoted_type = DispatchFunc, DataTypeList>( + {a_dtype, b_dtype}, [=]() { return DataTypeMap_v>; }, + "CUDA BinaryBackward"); CHECK(a_num_elements >= b_num_elements && a_num_elements % b_num_elements == 0); @@ -743,17 +860,25 @@ BinaryBackward(const std::shared_ptr &grad_output, const std::shared_ptr auto grad_a = std::make_shared(a_dims, promoted_type, device); auto grad_b = std::make_shared(b_dims, promoted_type, device); + // Only Fill(0) when broadcast is needed (atomicAdd requires zero-init). + // The no-broadcast fast path writes every element directly. + const bool needs_broadcast = !ShapesEqual(a_dims, b_dims) || !ShapesEqual(a_dims, grad_output->Dims()); + switch (promoted_type) { DISPATCH_CASE(WRAP({ - grad_a->Fill(0.0f); - grad_b->Fill(0.0f); + if (needs_broadcast) { + grad_a->Fill(0.0f); + grad_b->Fill(0.0f); + } LaunchBackward<256, float>(fn_a, fn_b, grad_a, grad_b, a_dims, b_dims, grad_output_promoted, a_promoted, b_promoted); }), DataType::kFLOAT32) DISPATCH_CASE(WRAP({ - grad_a->Fill(0); - grad_b->Fill(0); + if (needs_broadcast) { + grad_a->Fill(0); + grad_b->Fill(0); + } LaunchBackward<256, nv_bfloat16>(fn_a, fn_b, grad_a, grad_b, a_dims, b_dims, grad_output_promoted, a_promoted, b_promoted); }), diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index 334b257c..5b9b6781 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -6,6 +6,7 @@ #include #include +#include "infini_train/include/autograd/linear.h" #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" #include 
"infini_train/include/core/runtime/device_guard.h" @@ -91,16 +92,15 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptrDtype(); auto other_dtype = other->Dtype(); auto grad_output_dtype = grad_output->Dtype(); - DataType promoted_type - = DispatchFunc, DataTypeList, DataTypeList>( - {input_dtype, other_dtype, grad_output_dtype}, - [=]() { return DataTypeMap_v>; }, - "CUDA MatmulBackward"); - - auto input_promoted = input_dtype == promoted_type ? input : std::make_shared(input->To(promoted_type)); - auto other_promoted = other_dtype == promoted_type ? other : std::make_shared(other->To(promoted_type)); + // Compute dtype determined by saved tensors (forward compute dtype), not grad_output + DataType compute_dtype = DispatchFunc, DataTypeList>( + {input_dtype, other_dtype}, [=]() { return DataTypeMap_v>; }, + "CUDA MatmulBackward"); + + auto input_promoted = input_dtype == compute_dtype ? input : std::make_shared(input->To(compute_dtype)); + auto other_promoted = other_dtype == compute_dtype ? other : std::make_shared(other->To(compute_dtype)); auto grad_output_promoted - = grad_output_dtype == promoted_type ? grad_output : std::make_shared(grad_output->To(promoted_type)); + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); const auto &input_dims = input->Dims(); const auto &other_dims = other->Dims(); @@ -123,16 +123,12 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptr(input_dims, promoted_type, grad_output->GetDevice()); - auto grad_other = std::make_shared(other_dims, promoted_type, grad_output->GetDevice()); + // For bf16 compute, output in fp32 to preserve accumulation precision (matches PyTorch behavior) + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? 
DataType::kFLOAT32 : compute_dtype; + auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); + auto grad_other = std::make_shared(other_dims, output_dtype, grad_output->GetDevice()); - DispatchFunc( - promoted_type, - [=]() { - grad_input->Fill(0); - grad_other->Fill(0); - }, - "CUDA MatmulBackward"); + // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. auto device = input_promoted->GetDevice(); const float alpha = 1.0f, beta = 0.0f; @@ -151,18 +147,17 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptrDataPtr(), CUDA_R_32F, lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), DataType::kFLOAT32) - DISPATCH_CASE( - WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_16BF, lda, - stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_input->DataPtr(), - CUDA_R_16BF, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_16BF, + lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, + grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) } } @@ -177,18 +172,17 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptrDataPtr(), CUDA_R_32F, lda, stride_a, input_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), DataType::kFLOAT32) - DISPATCH_CASE( - WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), CUDA_R_16BF, - lda, 
stride_a, input_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_other->DataPtr(), - CUDA_R_16BF, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), + CUDA_R_16BF, lda, stride_a, input_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, + grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) } } @@ -303,8 +297,9 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons return output; } -template -__global__ void ReduceColumnsKernel(const T *__restrict__ input, T *__restrict__ output, int num_rows, int num_cols) { +template +__global__ void ReduceColumnsKernel(const TIn *__restrict__ input, TOut *__restrict__ output, int num_rows, + int num_cols) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -324,134 +319,175 @@ __global__ void ReduceColumnsKernel(const T *__restrict__ input, T *__restrict__ std::tuple, std::shared_ptr, std::shared_ptr> LinearBackward(const std::shared_ptr &input, const std::shared_ptr &weight, bool transpose, - int64_t out_features, const std::shared_ptr &grad_output, const bool bias) { - const auto &input_dims = input->Dims(); + int64_t in_features, int64_t out_features, const std::vector &input_dims, + const std::shared_ptr &grad_output, bool bias, + infini_train::autograd::LinearGradFlags grad_flags) { + const auto compute_grad_input = grad_flags.input; + const auto compute_grad_weight = grad_flags.weight; + const auto compute_grad_bias = grad_flags.bias; + CHECK_GE(input_dims.size(), 2); const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies{}); - const int64_t in_features = *input_dims.rbegin(); + + const std::vector weight_dims + = transpose ? 
std::vector{out_features, in_features} : std::vector{in_features, out_features}; auto dtype = grad_output->Dtype(); - auto input_dtype = input->Dtype(); - auto weight_dtype = weight->Dtype(); - DataType promoted_type - = DispatchFunc, DataTypeList, DataTypeList>( - {input_dtype, weight_dtype, dtype}, - [=]() { return DataTypeMap_v>; }, - "CUDA LinearBackward"); - - auto input_promoted = input_dtype == promoted_type ? input : std::make_shared(input->To(promoted_type)); - auto weight_promoted = weight_dtype == promoted_type ? weight : std::make_shared(weight->To(promoted_type)); + + // For type promotion, use available tensors + DataType input_dtype = input ? input->Dtype() : (weight ? weight->Dtype() : dtype); + DataType weight_dtype = weight ? weight->Dtype() : (input ? input->Dtype() : dtype); + // Compute dtype determined by saved tensors (forward compute dtype), not grad_output + DataType compute_dtype = DispatchFunc, DataTypeList>( + {input_dtype, weight_dtype}, [=]() { return DataTypeMap_v>; }, + "CUDA LinearBackward"); + auto grad_output_promoted - = dtype == promoted_type ? grad_output : std::make_shared(grad_output->To(promoted_type)); + = dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - const auto &weight_dims = weight->Dims(); - CHECK_EQ(weight_dims.size(), 2); - CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]); - CHECK_EQ(out_features, weight_dims[transpose ? 0 : 1]); + // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; - auto grad_input = std::make_shared(input_dims, promoted_type, grad_output->GetDevice()); - auto grad_weight = std::make_shared(weight_dims, promoted_type, grad_output->GetDevice()); + // Allocate only needed gradient tensors (selective save: input/weight may be nullptr). 
+ std::shared_ptr grad_input = nullptr; + std::shared_ptr grad_weight = nullptr; std::shared_ptr grad_bias = nullptr; - auto initialize_gradients = [&](auto zero_value, DataType dtype) { - using T = decltype(zero_value); - grad_input->Fill(zero_value); - grad_weight->Fill(zero_value); - if (bias) { - grad_bias = std::make_shared(std::vector{out_features}, dtype, grad_output->GetDevice()); - grad_bias->Fill(zero_value); - } - }; - DispatchFunc( - promoted_type, [=]() { initialize_gradients(T(0), promoted_type); }, "CUDA LinearBackward"); + if (compute_grad_input) { + grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); + } + if (compute_grad_weight) { + grad_weight = std::make_shared(weight_dims, output_dtype, grad_output->GetDevice()); + } + // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output, and ReduceColumnsKernel assigns directly. + if (compute_grad_bias && bias) { + grad_bias + = std::make_shared(std::vector{out_features}, output_dtype, grad_output->GetDevice()); + } - auto device = input_promoted->GetDevice(); + auto device = grad_output->GetDevice(); const auto &cuda_stream = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); float alpha = 1.0f; float beta = 0.0f; - auto trans_a1 = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; - auto trans_b1 = CUBLAS_OP_N; - auto lda1 = transpose ? in_features : out_features; - auto trans_a2 = CUBLAS_OP_N; - auto trans_b2 = CUBLAS_OP_T; - int m2 = transpose ? in_features : out_features; - int n2 = transpose ? out_features : in_features; - const void *a2 = transpose ? input_promoted->DataPtr() : grad_output_promoted->DataPtr(); - const void *b2 = transpose ? grad_output_promoted->DataPtr() : input_promoted->DataPtr(); - auto lda2 = transpose ? in_features : out_features; - auto ldb2 = transpose ? out_features : in_features; - auto ldc2 = transpose ? 
in_features : out_features; cublasHandle_t handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); - switch (promoted_type) { + switch (compute_dtype) { // TODO(zbl): use cublasSgemv if possible + DISPATCH_CASE( + WRAP({ + if (compute_grad_input) { + // - if transpose: + // weight is [out_features, in_features] here + // d_input = d_output * weight --> d_input.T = weight.T * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[in_features, out_features] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // weight is [in_features, out_features] here + // d_input = d_output * weight.T --> d_input.T = weight * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[out_features, in_features] + // B = d_output.T[out_features, bs] + CHECK(weight != nullptr) + << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; + auto weight_promoted + = weight_dtype == compute_dtype ? weight : std::make_shared(weight->To(compute_dtype)); + auto trans_a1 = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; + auto lda1 = transpose ? 
in_features : out_features; + CUBLAS_CHECK(cublasSgemm(handle, trans_a1, CUBLAS_OP_N, in_features, bs, out_features, &alpha, + static_cast(weight_promoted->DataPtr()), lda1, + static_cast(grad_output_promoted->DataPtr()), out_features, + &beta, static_cast(grad_input->DataPtr()), in_features)); + } + if (compute_grad_weight) { + // - if transpose: + // d_weight = d_output.T * input --> d_weight.T = input.T * d_output + // C = d_weight.T[in_features, out_features] + // A = input.T[in_features, bs] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // d_weight = input.T * d_output --> d_weight.T = d_output.T * input + // C = d_weight.T[out_features, in_features] + // A = d_output.T[out_features, bs] + // B = input.T[in_features, bs] + CHECK(input != nullptr) + << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; + auto input_promoted + = input_dtype == compute_dtype ? input : std::make_shared(input->To(compute_dtype)); + auto trans_a2 = CUBLAS_OP_N; + auto trans_b2 = CUBLAS_OP_T; + int m2 = transpose ? in_features : out_features; + int n2 = transpose ? out_features : in_features; + const void *a2 = transpose ? input_promoted->DataPtr() : grad_output_promoted->DataPtr(); + const void *b2 = transpose ? grad_output_promoted->DataPtr() : input_promoted->DataPtr(); + auto lda2 = transpose ? in_features : out_features; + auto ldb2 = transpose ? out_features : in_features; + auto ldc2 = transpose ? 
in_features : out_features; + CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, + static_cast(a2), lda2, static_cast(b2), ldb2, + &beta, static_cast(grad_weight->DataPtr()), ldc2)); + } + // d_bias = \sum_i(i=0, bs-1) d_output[i] + // TODO(dcj): use thrust::fill or reduce kernel do this + if (compute_grad_bias && bias) { + constexpr int BLOCK_SIZE = 256; + int threads_per_block = BLOCK_SIZE; + int num_blocks = out_features; + ReduceColumnsKernel<<>>( + static_cast(grad_output_promoted->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); + } + }), + DataType::kFLOAT32) DISPATCH_CASE(WRAP({ - // - if transpose: - // weight is [out_features, in_features] here - // d_input = d_output * weight --> d_input.T = weight.T * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[in_features, out_features] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // weight is [in_features, out_features] here - // d_input = d_output * weight.T --> d_input.T = weight * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[out_features, in_features] - // B = d_output.T[out_features, bs] - CUBLAS_CHECK(cublasSgemm(handle, trans_a1, trans_b1, in_features, bs, out_features, &alpha, - static_cast(weight_promoted->DataPtr()), lda1, - static_cast(grad_output_promoted->DataPtr()), - out_features, &beta, static_cast(grad_input->DataPtr()), - in_features)); - // - if transpose: - // d_weight = d_output.T * input --> d_weight.T = input.T * d_output - // C = d_weight.T[in_features, out_features] - // A = input.T[in_features, bs] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // d_weight = input.T * d_output --> d_weight.T = d_output.T * input - // C = d_weight.T[out_features, in_features] - // A = d_output.T[out_features, bs] - // B = input.T[in_features, bs] - CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, - static_cast(a2), lda2, static_cast(b2), - ldb2, &beta, 
static_cast(grad_weight->DataPtr()), ldc2)); - // d_bias = \sum_i(i=0, bs-1) d_output[i] - // TODO(dcj): use thrust::fill or reduce kernel do this - if (bias) { - constexpr int BLOCK_SIZE = 256; - int threads_per_block = BLOCK_SIZE; - int num_blocks = out_features; - ReduceColumnsKernel<<>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); + if (compute_grad_input) { + CHECK(weight != nullptr) + << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; + auto weight_promoted = weight_dtype == compute_dtype + ? weight + : std::make_shared(weight->To(compute_dtype)); + auto trans_a1 = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; + auto lda1 = transpose ? in_features : out_features; + CUBLAS_CHECK(cublasGemmEx(handle, trans_a1, CUBLAS_OP_N, in_features, bs, out_features, + &alpha, weight_promoted->DataPtr(), CUDA_R_16BF, lda1, + grad_output_promoted->DataPtr(), CUDA_R_16BF, out_features, + &beta, grad_input->DataPtr(), CUDA_R_32F, in_features, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); } - }), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasGemmEx(handle, trans_a1, trans_b1, in_features, bs, out_features, &alpha, - weight_promoted->DataPtr(), CUDA_R_16BF, lda1, - grad_output_promoted->DataPtr(), CUDA_R_16BF, out_features, &beta, - grad_input->DataPtr(), CUDA_R_16BF, in_features, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT)); - CUBLAS_CHECK(cublasGemmEx(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, a2, CUDA_R_16BF, - lda2, b2, CUDA_R_16BF, ldb2, &beta, grad_weight->DataPtr(), - CUDA_R_16BF, ldc2, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - if (bias) { + if (compute_grad_weight) { + CHECK(input != nullptr) + << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; + auto input_promoted = input_dtype == compute_dtype + ? input + : std::make_shared(input->To(compute_dtype)); + auto trans_a2 = CUBLAS_OP_N; + auto trans_b2 = CUBLAS_OP_T; + int m2 = transpose ? 
in_features : out_features; + int n2 = transpose ? out_features : in_features; + const void *a2 = transpose ? input_promoted->DataPtr() : grad_output_promoted->DataPtr(); + const void *b2 = transpose ? grad_output_promoted->DataPtr() : input_promoted->DataPtr(); + auto lda2 = transpose ? in_features : out_features; + auto ldb2 = transpose ? out_features : in_features; + auto ldc2 = transpose ? in_features : out_features; + CUBLAS_CHECK(cublasGemmEx(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, a2, CUDA_R_16BF, + lda2, b2, CUDA_R_16BF, ldb2, &beta, grad_weight->DataPtr(), + CUDA_R_32F, ldc2, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + } + if (compute_grad_bias && bias) { constexpr int BLOCK_SIZE = 256; int threads_per_block = BLOCK_SIZE; int num_blocks = out_features; ReduceColumnsKernel<<>>( static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); + static_cast(grad_bias->DataPtr()), out_features, bs); } }), DataType::kBFLOAT16) diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index 7d73d893..ae7c9f7b 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -80,19 +80,20 @@ std::tuple, std::shared_ptr> OuterBackward(const auto other_dtype = other->Dtype(); auto grad_output_dtype = grad_output->Dtype(); - DataType promoted_type - = DispatchFunc, DataTypeList, DataTypeList>( - {input_dtype, other_dtype, grad_output_dtype}, - [=]() { return DataTypeMap_v>; }, - "CUDA OuterBackward"); + // Compute dtype determined by saved tensors (forward compute dtype), not grad_output + DataType promoted_type = DispatchFunc, DataTypeList>( + {input_dtype, other_dtype}, [=]() { return DataTypeMap_v>; }, + "CUDA OuterBackward"); auto input_promoted = input_dtype == promoted_type ? input : std::make_shared(input->To(promoted_type)); auto other_promoted = other_dtype == promoted_type ? 
other : std::make_shared(other->To(promoted_type)); auto grad_output_promoted = grad_output_dtype == promoted_type ? grad_output : std::make_shared(grad_output->To(promoted_type)); - auto grad_input = std::make_shared(std::vector{M}, promoted_type, grad_output->GetDevice()); - auto grad_other = std::make_shared(std::vector{N}, promoted_type, grad_output->GetDevice()); + // For bf16 compute, output in fp32 to preserve accumulation precision (matches PyTorch behavior) + auto output_dtype = (promoted_type == DataType::kBFLOAT16) ? DataType::kFLOAT32 : promoted_type; + auto grad_input = std::make_shared(std::vector{M}, output_dtype, grad_output->GetDevice()); + auto grad_other = std::make_shared(std::vector{N}, output_dtype, grad_output->GetDevice()); DispatchFunc( promoted_type, @@ -140,7 +141,7 @@ std::tuple, std::shared_ptr> OuterBackward(const // B = grad_output.T[N, M] CUBLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, M, N, &alpha, other_promoted->DataPtr(), CUDA_R_16BF, 1, grad_output_promoted->DataPtr(), CUDA_R_16BF, N, &beta, - grad_input->DataPtr(), CUDA_R_16BF, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + grad_input->DataPtr(), CUDA_R_32F, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); // grad_other[N, 1] = grad_output.T[N, M] × input[M, 1] // grad_other.T[1, N] = input.T[1, M] × grad_output[M, N] // C = grad_other.T[1, N] @@ -148,7 +149,7 @@ std::tuple, std::shared_ptr> OuterBackward(const // B = grad_output.T[N, M] CUBLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, 1, N, M, &alpha, input_promoted->DataPtr(), CUDA_R_16BF, 1, grad_output_promoted->DataPtr(), CUDA_R_16BF, N, &beta, - grad_other->DataPtr(), CUDA_R_16BF, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + grad_other->DataPtr(), CUDA_R_32F, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); }), DataType::kBFLOAT16) } diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index 62f3054a..b1766bbd 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -22,8 +22,20 @@ namespace 
infini_train::nn::init { namespace { -static std::random_device rd; -static std::mt19937 gen(rd()); +constexpr int kRandomSeed = 42; + +// FIXME: RNG design is incomplete. +// +// Current implementation lacks: +// - unified Generator abstraction +// - global default generator and seed control +// - reproducible / clonable RNG state +// +// TODO: +// - introduce Generator interface and backend impl +// - add default generator management (per device) +// - refactor random ops to consume Generator +static std::mt19937 gen(kRandomSeed); } // namespace std::shared_ptr Normal(const std::shared_ptr &tensor, float mean, float std, @@ -34,7 +46,7 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean #ifdef USE_OMP #pragma omp parallel { - std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num()); + std::mt19937 local_gen(kRandomSeed + omp_get_thread_num()); std::normal_distribution local_dis(mean, std); #pragma omp for for (int i = 0; i < buffer.size(); ++i) { @@ -126,7 +138,7 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, #ifdef USE_OMP #pragma omp parallel { - std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num()); + std::mt19937 local_gen(kRandomSeed + omp_get_thread_num()); std::uniform_real_distribution local_dis(a, b); #pragma omp for for (int i = 0; i < buffer.size(); ++i) { diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 6c243fea..490b9009 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -398,6 +398,11 @@ std::shared_ptr Tensor::Contiguous() { return std::make_shared(dims_)->Apply({shared_from_this()})[0]; } +// FIXME: Requires stride tracking in the Tensor class before this can be implemented +// correctly. Currently always returns true as a placeholder. The contiguous guard in +// elementwise.cu ensures non-contiguous tensors fall back to the broadcast path. 
+bool Tensor::IsContiguous() const { return true; } + std::shared_ptr Tensor::Flatten(int64_t start, int64_t end) { auto ndim = dims_.size(); auto start_dim = start >= 0 ? start : start + ndim; diff --git a/scripts/compare_loss.py b/scripts/compare_loss.py index 31b2a009..900d1325 100755 --- a/scripts/compare_loss.py +++ b/scripts/compare_loss.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # Usage: # python tools/compare_loss.py \ -# /data/shared/InfiniTrain-dev/logs/202511_a800/20260105/feature/add_1F1B_f2a383a/logs \ -# /data/shared/InfiniTrain-dev/logs/202511_a800/20251223/feature/tp-pp-split-stream/logs \ +# /path/to/baseline/logs \ +# /path/to/test/logs \ # --threshold-fp32 1e-5 --threshold-bf16 1e-2 import re @@ -50,8 +50,8 @@ def compare_files(file1, file2, threshold): def main(): parser = ArgumentParser(description='Compare training loss between two log directories') - parser.add_argument('dir1', type=Path, help='First log directory') - parser.add_argument('dir2', type=Path, help='Second log directory') + parser.add_argument('dir1', type=Path, help='Baseline log directory') + parser.add_argument('dir2', type=Path, help='Test log directory') parser.add_argument('--threshold', type=float, help='Loss difference threshold (deprecated, use --threshold-fp32 and --threshold-bf16)') parser.add_argument('--threshold-fp32', type=float, default=1e-5, help='Loss difference threshold for fp32 (default: 1e-5)') parser.add_argument('--threshold-bf16', type=float, default=1e-2, help='Loss difference threshold for bfloat16 (default: 1e-2)') @@ -63,6 +63,10 @@ def main(): args.threshold_fp32 = args.threshold args.threshold_bf16 = args.threshold + print(f"Baseline: {args.dir1.resolve()}") + print(f"Test: {args.dir2.resolve()}") + print() + files1, duplicates1 = collect_log_files(args.dir1) files2, duplicates2 = collect_log_files(args.dir2) exit_if_duplicate_logs(args.dir1, duplicates1) @@ -73,9 +77,9 @@ def main(): common = set(files1.keys()) & set(files2.keys()) if only_in_1: - 
print(f"Files only in {args.dir1.resolve()}: {', '.join(sorted(only_in_1))}") + print(f"Files only in baseline: {', '.join(sorted(only_in_1))}") if only_in_2: - print(f"Files only in {args.dir2.resolve()}: {', '.join(sorted(only_in_2))}") + print(f"Files only in test: {', '.join(sorted(only_in_2))}") if only_in_1 or only_in_2: print() diff --git a/scripts/compare_tps.py b/scripts/compare_tps.py index de6327de..d64d8c9d 100755 --- a/scripts/compare_tps.py +++ b/scripts/compare_tps.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # Usage: # python tools/compare_tps.py \ -# /path/to/logs/dir1 \ -# /path/to/logs/dir2 \ +# /path/to/baseline/logs \ +# /path/to/test/logs \ # --threshold 0.20 import re @@ -32,30 +32,43 @@ def compare_files(file1, file2, threshold): tps2 = {k: v for k, v in tps2.items() if k > 1} if not tps1 or not tps2: - return 0, 1, [" No valid steps found (after excluding step 1)"], 0, 0, 0 + return 0, True, [" No valid steps found (after excluding step 1)"], 0, 0, 0, 0, 0 # Calculate averages avg1 = sum(tps1.values()) / len(tps1) avg2 = sum(tps2.values()) / len(tps2) - # Calculate relative error - rel_error = abs(avg1 - avg2) / max(avg1, avg2) if max(avg1, avg2) > 0 else 0 - - mismatches = [] - if rel_error > threshold: - mismatches.append(f" Average tok/s: {avg1:.2f} vs {avg2:.2f} ✗ (error: {rel_error*100:.2f}%, threshold: {threshold*100:.2f}%)") - mismatches.append(f" Steps compared: {len(tps1)} vs {len(tps2)} (excluding step 1)") + # Calculate signed relative change of test vs baseline: positive means test faster, negative means test slower + signed_change = (avg2 - avg1) / avg1 if avg1 > 0 else 0 + + messages = [] + failed = False + if abs(signed_change) > threshold: + sign = "+" if signed_change >= 0 else "" + if signed_change < 0: + # test slower than baseline -> failure + label = "✗ SLOWER" + failed = True + else: + # test faster than baseline -> pass but notify + label = "↑ FASTER" + messages.append(f" Average tok/s: {avg1:.2f} (baseline) vs 
{avg2:.2f} (test) {label} ({sign}{signed_change*100:.1f}%, threshold: ±{threshold*100:.0f}%)") + messages.append(f" Steps compared: {len(tps1)} vs {len(tps2)} (excluding step 1)") - return 1, len(mismatches), mismatches, avg1, avg2, rel_error + return 1, failed, messages, avg1, avg2, signed_change, len(tps1), len(tps2) def main(): parser = ArgumentParser(description='Compare tok/s between two log directories') - parser.add_argument('dir1', type=Path, help='First log directory') - parser.add_argument('dir2', type=Path, help='Second log directory') + parser.add_argument('dir1', type=Path, help='Baseline log directory') + parser.add_argument('dir2', type=Path, help='Test log directory') parser.add_argument('--threshold', type=float, default=0.20, help='Relative error threshold (default: 0.20 = 20%%)') parser.add_argument('--verbose', action='store_true', help='Print detailed output for all files, including passed ones') args = parser.parse_args() + print(f"Baseline: {args.dir1.resolve()}") + print(f"Test: {args.dir2.resolve()}") + print() + files1, duplicates1 = collect_log_files(args.dir1) files2, duplicates2 = collect_log_files(args.dir2) exit_if_duplicate_logs(args.dir1, duplicates1) @@ -66,43 +79,70 @@ def main(): common = set(files1.keys()) & set(files2.keys()) if only_in_1: - print(f"Files only in {args.dir1.resolve()}: {', '.join(sorted(only_in_1))}") + print(f"Files only in baseline: {', '.join(sorted(only_in_1))}") if only_in_2: - print(f"Files only in {args.dir2.resolve()}: {', '.join(sorted(only_in_2))}") + print(f"Files only in test: {', '.join(sorted(only_in_2))}") if only_in_1 or only_in_2: print() - total_mismatches = 0 - total_files = 0 - passed_files = 0 + improvements = [] # (name, avg1, avg2, signed_change) + regressions = [] + normal = [] for name in sorted(common): - total_files += 1 - total_comparisons, num_mismatches, mismatches, avg1, avg2, rel_error = compare_files(files1[name], files2[name], args.threshold) - - if mismatches: - 
print(f"Comparing {name}:") - for msg in mismatches: - print(msg) - total_mismatches += num_mismatches + _, failed, messages, avg1, avg2, signed_change, steps1, steps2 = compare_files(files1[name], files2[name], args.threshold) + if failed: + regressions.append((name, avg1, avg2, signed_change, steps1, steps2)) + elif messages: + improvements.append((name, avg1, avg2, signed_change, steps1, steps2)) else: - passed_files += 1 - # Only print details when verbose mode is enabled - if args.verbose: - print(f"Comparing {name}:") - print(f" Average tok/s: {avg1:.2f} vs {avg2:.2f} ✓ (error: {rel_error*100:.2f}%, threshold: {args.threshold*100:.2f}%)") - print(f" Steps compared: {len([k for k in parse_log(files1[name]) if k > 1])} (excluding step 1)") - - # Print separator when there are mismatches or verbose mode - if mismatches or args.verbose: + normal.append((name, avg1, avg2, signed_change, steps1, steps2)) + + pct = f"{args.threshold*100:.0f}%" + + if improvements: + print(f"{'=' * 50}") + print(f"Large improvements (>{pct})") + print(f"{'=' * 50}") + for name, avg1, avg2, signed_change, steps1, steps2 in improvements: + sign = "+" if signed_change >= 0 else "" + print(f"[PASS] {name}") + print(f" average tok/s: {avg1:.2f} vs {avg2:.2f} ({sign}{signed_change*100:.1f}%)") + print(f" effective steps: {steps1} vs {steps2} (step > 1)") print() - print("=" * 50) - print(f"Overall Summary:") - print(f" {passed_files}/{total_files} test cases passed (threshold: {args.threshold*100:.0f}%)") - print("=" * 50) + if regressions: + print(f"{'=' * 50}") + print(f"FAILED regressions (>{pct})") + print(f"{'=' * 50}") + for name, avg1, avg2, signed_change, steps1, steps2 in regressions: + sign = "+" if signed_change >= 0 else "" + print(f"[FAIL] {name}") + print(f" average tok/s: {avg1:.2f} vs {avg2:.2f} ({sign}{signed_change*100:.1f}%)") + print(f" effective steps: {steps1} vs {steps2} (step > 1)") + print() + + if args.verbose and normal: + print(f"{'=' * 50}") + print(f"Within 
threshold (<={pct})") + print(f"{'=' * 50}") + for name, avg1, avg2, signed_change, steps1, steps2 in normal: + sign = "+" if signed_change >= 0 else "" + print(f"[PASS] {name}") + print(f" tok/s: {avg1:.2f} vs {avg2:.2f} ({sign}{signed_change*100:.1f}%)") + print() - sys.exit(1 if total_mismatches > 0 else 0) + total = len(improvements) + len(regressions) + len(normal) + passed = len(improvements) + len(normal) + print(f"{'=' * 50}") + print(f"Summary: {passed}/{total} test cases passed") + print(f"failed regressions : {len(regressions)}") + print(f"large improvements : {len(improvements)}") + print(f"within threshold : {len(normal)}") + print(f"total cases : {total}") + print(f"{'=' * 50}") + + sys.exit(1 if regressions else 0) if __name__ == '__main__': main() diff --git a/test/hook/test_hook.cc b/test/hook/test_hook.cc deleted file mode 100644 index 32c7e097..00000000 --- a/test/hook/test_hook.cc +++ /dev/null @@ -1,179 +0,0 @@ -#include -#include - -#include "glog/logging.h" - -#include "infini_train/include/autograd/elementwise.h" -#include "infini_train/include/autograd/function.h" -#include "infini_train/include/autograd/function_hook.h" -#include "infini_train/include/common/hook.h" -#include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/tensor.h" - -using namespace infini_train; - -// ============================================================================ -// Test 1: Basic Module Hooks -// ============================================================================ -void test_basic_hooks() { - std::cout << "\n=== Test 1: Basic Module Hooks ===" << std::endl; - - auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); - x->set_requires_grad(true); - - // Module hook example - class MyModule : public nn::Module { - public: - MyModule() : Module("MyModule") {} - - std::vector> Forward(const std::vector> &inputs) override { - std::cout << "Forward pass 
executing..." << std::endl; - return inputs; - } - }; - - auto module = std::make_shared(); - - // Register forward pre-hook - auto pre_hook - = module->RegisterForwardPreHook([](nn::Module *mod, const std::vector> &inputs) { - std::cout << "Forward pre-hook: Module type = " << mod->type() << std::endl; - }); - - // Register forward post-hook - auto fwd_hook - = module->RegisterForwardPostHook([](nn::Module *mod, const std::vector> &inputs, - const std::vector> &outputs) { - std::cout << "Forward post-hook: Got " << outputs.size() << " outputs" << std::endl; - }); - - // Register backward pre-hook - auto bwd_pre_hook = module->RegisterBackwardPreHook( - [](nn::Module *mod, const std::vector> &grad_outputs) { - std::cout << "Backward pre-hook called!" << std::endl; - }); - - // Register backward post-hook - auto bwd_post_hook - = module->RegisterBackwardPostHook([](nn::Module *mod, const std::vector> &grad_inputs, - const std::vector> &grad_outputs) { - std::cout << "Backward post-hook called!" << std::endl; - }); - - // Test forward pass - std::vector> inputs = {x}; - auto outputs = (*module)(inputs); - - std::cout << "Module hook test completed!" 
<< std::endl; -} - -// ============================================================================ -// Test 2: Hook Remove() Functionality Test -// ============================================================================ -void test_hook_remove() { - std::cout << "\n=== Test 2: Hook Remove() Functionality Test ===" << std::endl; - - auto a = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32); - auto b = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32); - a->set_requires_grad(true); - b->set_requires_grad(true); - - int hook1_count = 0; - int hook2_count = 0; - int hook3_count = 0; - - auto add_fn = std::make_shared(); - - // Register three forward pre-hooks - auto handle1 = add_fn->RegisterForwardPreHook( - [&hook1_count](autograd::Function *, const std::vector> &) { - hook1_count++; - std::cout << "Hook 1 called (count: " << hook1_count << ")" << std::endl; - }); - - auto handle2 = add_fn->RegisterForwardPreHook( - [&hook2_count](autograd::Function *, const std::vector> &) { - hook2_count++; - std::cout << "Hook 2 called (count: " << hook2_count << ")" << std::endl; - }); - - auto handle3 = add_fn->RegisterForwardPreHook( - [&hook3_count](autograd::Function *, const std::vector> &) { - hook3_count++; - std::cout << "Hook 3 called (count: " << hook3_count << ")" << std::endl; - }); - - // First call - all hooks should fire - std::cout << "\n--- First Apply (all hooks active) ---" << std::endl; - std::vector> inputs; - inputs.push_back(a); - inputs.push_back(b); - auto result1 = add_fn->Apply(inputs); - std::cout << "Hook counts: " << hook1_count << ", " << hook2_count << ", " << hook3_count << std::endl; - - // Remove hook 2 - std::cout << "\n--- Removing Hook 2 ---" << std::endl; - handle2->Remove(); - - // Second call - hook 2 should not fire - std::cout << "\n--- Second Apply (hook 2 removed) ---" << std::endl; - auto result2 = add_fn->Apply(inputs); - std::cout << "Hook counts: " << hook1_count << ", " << hook2_count << ", " << hook3_count << 
std::endl; - - // Remove hook 1 - std::cout << "\n--- Removing Hook 1 ---" << std::endl; - handle1->Remove(); - - // Third call - only hook 3 should fire - std::cout << "\n--- Third Apply (hooks 1 and 2 removed) ---" << std::endl; - auto result3 = add_fn->Apply(inputs); - std::cout << "Hook counts: " << hook1_count << ", " << hook2_count << ", " << hook3_count << std::endl; - - // Verify results - std::cout << "\n=== Test Results ===" << std::endl; - bool test_passed = true; - - if (hook1_count != 2) { - std::cout << "FAIL: Hook 1 should be called 2 times, got " << hook1_count << std::endl; - test_passed = false; - } - - if (hook2_count != 1) { - std::cout << "FAIL: Hook 2 should be called 1 time, got " << hook2_count << std::endl; - test_passed = false; - } - - if (hook3_count != 3) { - std::cout << "FAIL: Hook 3 should be called 3 times, got " << hook3_count << std::endl; - test_passed = false; - } - - if (test_passed) { - std::cout << "SUCCESS: All hooks behaved correctly!" << std::endl; - std::cout << " - Hook 1: called 2 times (before removal)" << std::endl; - std::cout << " - Hook 2: called 1 time (removed after first call)" << std::endl; - std::cout << " - Hook 3: called 3 times (never removed)" << std::endl; - } -} - -// ============================================================================ -// Main -// ============================================================================ -int main(int argc, char *argv[]) { - google::InitGoogleLogging(argv[0]); - nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); - - std::cout << "========================================" << std::endl; - std::cout << " Hook Mechanism Tests" << std::endl; - std::cout << "========================================" << std::endl; - - test_basic_hooks(); - test_hook_remove(); - - std::cout << "\n========================================" << std::endl; - std::cout << " All Tests Completed Successfully" << std::endl; - std::cout << 
"========================================" << std::endl; - - return 0; -} diff --git a/test/hook/test_precision_check.cc b/test/hook/test_precision_check.cc deleted file mode 100644 index 65c8258c..00000000 --- a/test/hook/test_precision_check.cc +++ /dev/null @@ -1,241 +0,0 @@ -#include -#include -#include - -#include "glog/logging.h" - -#include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/tensor.h" -#include "infini_train/include/utils/global_module_hook_registry.h" -#include "infini_train/include/utils/precision_check_config.h" -#include "infini_train/include/utils/precision_checker.h" - -using namespace infini_train; - -class MyModel : public nn::Module { -public: - MyModel() : Module("MyModel") {} - - std::vector> Forward(const std::vector> &inputs) override { - auto x = inputs[0]; - x->RequiresGrad(); - auto y = x->Mul(x); - return {y}; - } -}; - -// Simple model for multi-iteration test -class SimpleModel : public nn::Module { -public: - SimpleModel() : Module("SimpleModel") {} - - std::vector> Forward(const std::vector> &inputs) override { - auto x = inputs[0]; - x->RequiresGrad(); - auto y = x->Mul(x)->Mul(x); // x^3 - return {y}; - } -}; - -void RunModelForwardBackward(const std::shared_ptr &model) { - auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); - x->Fill(2.0f); - x->RequiresGrad(); - - std::vector> inputs = {x}; - auto outputs = (*model)(inputs); - auto loss = outputs[0]->Sum(0, false)->Sum(0, false); - loss->Backward(); -} - -void TestFunctionLevel(const std::string &config_str) { - std::cout << "\n========================================" << std::endl; - std::cout << " Function-Level Test: " << config_str << std::endl; - std::cout << "========================================" << std::endl; - - auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); - x->Fill(2.0f); - x->RequiresGrad(); - - auto y = std::make_shared(std::vector{2, 3}, 
DataType::kFLOAT32); - y->Fill(3.0f); - y->RequiresGrad(); - - auto z = x->Mul(y); - auto loss = z->Sum(0, false)->Sum(0, false); - loss->Backward(); - - std::cout << "Test completed." << std::endl; -} - -void TestModuleLevel(const std::string &config_str) { - std::cout << "\n========================================" << std::endl; - std::cout << " Module-Level Test: " << config_str << std::endl; - std::cout << "========================================" << std::endl; - - auto model = std::make_shared(); - RunModelForwardBackward(model); - - std::cout << "Test completed." << std::endl; -} - -// Test: Simple format output (level=2, format=simple) -void TestSimpleFormat() { - std::cout << "\n========================================" << std::endl; - std::cout << " Test: Simple Format (level=2, format=simple)" << std::endl; - std::cout << "========================================" << std::endl; - - auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); - x->Fill(2.0f); - x->RequiresGrad(); - - auto y = x->Mul(x); - auto loss = y->Sum(0, false)->Sum(0, false); // Two Sum ops to produce scalar - loss->Backward(); - - std::cout << "Simple format test completed - check output for min/max/mean values." << std::endl; -} - -// Test: MD5 format output (level=2, format=md5) -void TestMd5Format() { - std::cout << "\n========================================" << std::endl; - std::cout << " Test: MD5 Format (level=2, format=md5)" << std::endl; - std::cout << "========================================" << std::endl; - - auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); - x->Fill(2.0f); - x->RequiresGrad(); - - auto y = x->Mul(x); - auto loss = y->Sum(0, false)->Sum(0, false); // Two Sum ops to produce scalar - loss->Backward(); - - std::cout << "MD5 format test completed - check output for md5 hashes." 
<< std::endl; -} - -// Test: Save tensors to NPY files (level=1, save_tensors=true) -void TestSaveTensors() { - std::cout << "\n========================================" << std::endl; - std::cout << " Test: Save Tensors (level=1, save_tensors=true)" << std::endl; - std::cout << "========================================" << std::endl; - - std::string output_path = "/tmp/precision_check_npy"; - - auto model = std::make_shared(); - RunModelForwardBackward(model); - - // Verify NPY files were created - namespace fs = std::filesystem; - bool found_npy = false; - if (fs::exists(output_path)) { - for (const auto &entry : fs::recursive_directory_iterator(output_path)) { - if (entry.path().extension() == ".npy") { - found_npy = true; - std::cout << "Found NPY file: " << entry.path() << std::endl; - } - } - } - - if (found_npy) { - std::cout << "Save tensors test PASSED - NPY files created successfully." << std::endl; - } else { - std::cout << "Save tensors test completed - check output directory for NPY files." 
<< std::endl; - } -} - -// Test: Multi-iteration file overwrite (level=1, save_tensors=true, iter=3) -void TestMultiIterOverwrite() { - std::cout << "\n========================================" << std::endl; - std::cout << " Test: Multi-Iteration File Overwrite" << std::endl; - std::cout << "========================================" << std::endl; - - std::string output_path = "/tmp/precision_check_overwrite"; - - auto model = std::make_shared(); - int num_iters = 3; - - // Run multiple iterations - files should be overwritten - for (int i = 0; i < num_iters; ++i) { - std::cout << "Iteration " << (i + 1) << "/" << num_iters << std::endl; - utils::PrecisionCheckEnv::ResetCounters(); // Reset counters each iteration - RunModelForwardBackward(model); - } - - namespace fs = std::filesystem; - int npy_count = 0; - if (fs::exists(output_path)) { - for (const auto &entry : fs::recursive_directory_iterator(output_path)) { - if (entry.path().extension() == ".npy") { - ++npy_count; - } - } - } - - std::cout << "Multi-iteration test completed - found " << npy_count << " NPY files after " << num_iters - << " iterations." << std::endl; - std::cout << "(Files should be overwritten each iteration, count should be consistent with 1 iter)" << std::endl; -} - -int main(int argc, char *argv[]) { - google::InitGoogleLogging(argv[0]); - - std::string config_str = argc > 1 ? argv[1] : ""; - - std::cout << "========================================" << std::endl; - std::cout << " Precision Check Test Suite" << std::endl; - std::cout << "========================================" << std::endl; - - nn::parallel::global::InitAllEnv(1, 1, false, 1, 1); - - // If no config argument, run all format tests - if (config_str.empty()) { - auto config = utils::PrecisionCheckConfig::Parse("level=2,format=simple"); - utils::PrecisionCheckEnv::Instance().Init(config); - - std::cout << "\nRunning all precision check format tests..." 
<< std::endl; - - // Test 1: Simple format - TestSimpleFormat(); - - // Test 2: MD5 format - auto md5_config = utils::PrecisionCheckConfig::Parse("level=2,format=md5"); - utils::PrecisionCheckEnv::Instance().Init(md5_config); - TestMd5Format(); - - // Test 3: Save tensors - auto npy_config = utils::PrecisionCheckConfig::Parse("level=1,save_tensors=true"); - utils::PrecisionCheckEnv::Instance().Init(npy_config); - TestSaveTensors(); - - // Test 4: Multi-iteration overwrite - auto iter_config = utils::PrecisionCheckConfig::Parse("level=1,save_tensors=true"); - utils::PrecisionCheckEnv::Instance().Init(iter_config); - TestMultiIterOverwrite(); - - std::cout << "\n========================================" << std::endl; - std::cout << " All Tests Completed Successfully" << std::endl; - std::cout << "========================================" << std::endl; - return 0; - } - - // If config provided, run single test (original behavior) - auto config = utils::PrecisionCheckConfig::Parse(config_str); - utils::PrecisionCheckEnv::Instance().Init(config); - - std::cout << "Config: " << config_str << std::endl; - - if (config.level == utils::PrecisionCheckLevel::MODULE) { - TestModuleLevel(config_str); - } else if (config.level == utils::PrecisionCheckLevel::FUNCTION) { - TestFunctionLevel(config_str); - } else { - std::cout << "No tests to run (level=0)" << std::endl; - } - - std::cout << "\n========================================" << std::endl; - std::cout << " Test Completed" << std::endl; - std::cout << "========================================" << std::endl; - - return 0; -} diff --git a/test/lora/test_lora.cc b/test/lora/test_lora.cc deleted file mode 100644 index 06966809..00000000 --- a/test/lora/test_lora.cc +++ /dev/null @@ -1,860 +0,0 @@ -#include -#include -#include - -#include "glog/logging.h" - -#include "infini_train/include/nn/lora/lora_config.h" -#include "infini_train/include/nn/lora/lora_linear.h" -#include "infini_train/include/nn/lora/lora_utils.h" -#include 
"infini_train/include/nn/modules/container.h" -#include "infini_train/include/nn/modules/linear.h" -#include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/tensor.h" - -using namespace infini_train; -using namespace infini_train::nn::lora; - -// ============================================================================ -// Test 1: LoRAConfig -// ============================================================================ -void test_lora_config() { - std::cout << "\n=== Test 1: LoRAConfig ===" << std::endl; - - LoRAConfig config; - config.rank = 8; - config.alpha = 16.0f; - - // Test scaling calculation - float expected_scaling = 16.0f / 8.0f; - CHECK_EQ(config.Scaling(), expected_scaling) << "Scaling calculation failed"; - std::cout << "Scaling: " << config.Scaling() << " (expected: " << expected_scaling << ")" << std::endl; - - // Test ShouldApplyLoRA - CHECK(config.ShouldApplyLoRA("c_attn")) << "Should match c_attn"; - CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_attn")) << "Should match nested c_attn"; - CHECK(config.ShouldApplyLoRA("c_proj")) << "Should match c_proj"; - CHECK(!config.ShouldApplyLoRA("c_fc")) << "Should not match c_fc (not in default targets)"; - CHECK(!config.ShouldApplyLoRA("random_layer")) << "Should not match random_layer"; - - std::cout << "LoRAConfig tests passed!" 
<< std::endl; -} - -// ============================================================================ -// Test 2: LoRALinear Initialization -// ============================================================================ -void test_lora_linear_init() { - std::cout << "\n=== Test 2: LoRALinear Initialization ===" << std::endl; - - LoRAConfig config; - config.rank = 4; - config.alpha = 8.0f; - - int64_t in_features = 64; - int64_t out_features = 128; - - auto lora_linear - = std::shared_ptr(new LoRALinear(in_features, out_features, config, /*bias=*/true, nullptr)); - - // Check parameter shapes - auto weight = lora_linear->parameter(nn::Linear::kParamWeightName); - auto bias = lora_linear->parameter(nn::Linear::kParamBiasName); - auto lora_A = lora_linear->parameter(LoRALinear::kParamLoraAName); - auto lora_B = lora_linear->parameter(LoRALinear::kParamLoraBName); - - CHECK_EQ(weight->Dims().size(), 2); - CHECK_EQ(weight->Dims()[0], out_features); - CHECK_EQ(weight->Dims()[1], in_features); - std::cout << "Weight shape: [" << weight->Dims()[0] << ", " << weight->Dims()[1] << "]" << std::endl; - - CHECK_EQ(bias->Dims().size(), 1); - CHECK_EQ(bias->Dims()[0], out_features); - std::cout << "Bias shape: [" << bias->Dims()[0] << "]" << std::endl; - - CHECK_EQ(lora_A->Dims().size(), 2); - CHECK_EQ(lora_A->Dims()[0], config.rank); - CHECK_EQ(lora_A->Dims()[1], in_features); - std::cout << "LoRA A shape: [" << lora_A->Dims()[0] << ", " << lora_A->Dims()[1] << "]" << std::endl; - - CHECK_EQ(lora_B->Dims().size(), 2); - CHECK_EQ(lora_B->Dims()[0], out_features); - CHECK_EQ(lora_B->Dims()[1], config.rank); - std::cout << "LoRA B shape: [" << lora_B->Dims()[0] << ", " << lora_B->Dims()[1] << "]" << std::endl; - - // Check requires_grad - CHECK(!weight->requires_grad()) << "Base weight should be frozen"; - CHECK(!bias->requires_grad()) << "Base bias should be frozen"; - CHECK(lora_A->requires_grad()) << "LoRA A should be trainable"; - CHECK(lora_B->requires_grad()) << "LoRA B should 
be trainable"; - std::cout << "requires_grad check passed!" << std::endl; - - // Check LoRAParameters() returns only LoRA params - auto params = lora_linear->LoRAParameters(); - CHECK_EQ(params.size(), 2) << "LoRAParameters() should return only LoRA params"; - std::cout << "LoRAParameters() returns " << params.size() << " tensors (LoRA A and B)" << std::endl; - - std::cout << "LoRALinear initialization tests passed!" << std::endl; -} - -// ============================================================================ -// Test 3: LoRALinear Forward Pass -// ============================================================================ -void test_lora_linear_forward() { - std::cout << "\n=== Test 3: LoRALinear Forward Pass ===" << std::endl; - - LoRAConfig config; - config.rank = 4; - config.alpha = 8.0f; - - int64_t in_features = 64; - int64_t out_features = 128; - int64_t batch_size = 2; - int64_t seq_len = 10; - - auto lora_linear - = std::shared_ptr(new LoRALinear(in_features, out_features, config, /*bias=*/true, nullptr)); - - // Create input tensor - auto input = std::make_shared(std::vector{batch_size, seq_len, in_features}, DataType::kFLOAT32); - - // Forward pass - auto output = (*lora_linear)({input})[0]; - - // Check output shape - CHECK_EQ(output->Dims().size(), 3); - CHECK_EQ(output->Dims()[0], batch_size); - CHECK_EQ(output->Dims()[1], seq_len); - CHECK_EQ(output->Dims()[2], out_features); - std::cout << "Output shape: [" << output->Dims()[0] << ", " << output->Dims()[1] << ", " << output->Dims()[2] << "]" - << std::endl; - - std::cout << "LoRALinear forward pass tests passed!" 
<< std::endl; -} - -// ============================================================================ -// Test 4: LoRALinear Weight Merging -// ============================================================================ -void test_lora_linear_merge() { - std::cout << "\n=== Test 4: LoRALinear Weight Merging ===" << std::endl; - - LoRAConfig config; - config.rank = 4; - config.alpha = 8.0f; - - int64_t in_features = 32; - int64_t out_features = 64; - - auto lora_linear - = std::shared_ptr(new LoRALinear(in_features, out_features, config, /*bias=*/false, nullptr)); - - // Print weight sum before merge - auto weight_before = lora_linear->parameter(nn::Linear::kParamWeightName); - auto lora_A = lora_linear->parameter(LoRALinear::kParamLoraAName); - auto lora_B = lora_linear->parameter(LoRALinear::kParamLoraBName); - - float weight_before_sum = weight_before->EigenMatrix().sum(); - float lora_A_sum = lora_A->EigenMatrix().sum(); - float lora_B_sum = lora_B->EigenMatrix().sum(); - - std::cout << "\n--- Before Merge ---" << std::endl; - std::cout << "Base weight sum: " << weight_before_sum << std::endl; - std::cout << "LoRA A sum: " << lora_A_sum << std::endl; - std::cout << "LoRA B sum: " << lora_B_sum << std::endl; - std::cout << "Scaling (alpha/r): " << config.Scaling() << std::endl; - - // Create input - auto input = std::make_shared(std::vector{2, 5, in_features}, DataType::kFLOAT32); - input->EigenMatrix().setRandom(); - - // Get output before merge - auto output_before = (*lora_linear)({input})[0]; - float output_before_sum = output_before->EigenMatrix().sum(); - std::cout << "Output sum before merge: " << output_before_sum << std::endl; - - // Merge weights - CHECK(!lora_linear->IsMerged()) << "Should not be merged initially"; - lora_linear->MergeWeights(); - CHECK(lora_linear->IsMerged()) << "Should be merged after MergeWeights()"; - - // Verify LoRA params are frozen after merge - CHECK(!lora_A->requires_grad()) << "lora_A should be frozen after merge"; - 
CHECK(!lora_B->requires_grad()) << "lora_B should be frozen after merge"; - std::cout << "\nWeights merged successfully, LoRA params frozen" << std::endl; - - // Print weight sum after merge - auto weight_after = lora_linear->parameter(nn::Linear::kParamWeightName); - float weight_after_sum = weight_after->EigenMatrix().sum(); - std::cout << "\n--- After Merge ---" << std::endl; - std::cout << "Base weight sum after merge: " << weight_after_sum << std::endl; - std::cout << "Weight change (should be ~LoRA contribution): " << (weight_after_sum - weight_before_sum) - << std::endl; - - // Get output after merge - auto output_merged = (*lora_linear)({input})[0]; - float output_merged_sum = output_merged->EigenMatrix().sum(); - std::cout << "Output sum after merge: " << output_merged_sum << std::endl; - - // Verify: output_after should equal output_before (numerically) - std::cout << "\nVerification: output_before == output_after? " << std::endl; - std::cout << " Before: " << output_before_sum << std::endl; - std::cout << " After: " << output_merged_sum << std::endl; - std::cout << " Diff: " << std::abs(output_before_sum - output_merged_sum) << std::endl; - CHECK(std::abs(output_before_sum - output_merged_sum) < 1e-3) << "Outputs should be numerically identical!"; - - // Shape comparison (always same) - std::cout << "\nOutput shape: [" << output_before->Dims()[0] << ", " << output_before->Dims()[1] << ", " - << output_before->Dims()[2] << "] (unchanged)" << std::endl; - - // Unmerge weights - lora_linear->UnmergeWeights(); - CHECK(!lora_linear->IsMerged()) << "Should not be merged after UnmergeWeights()"; - - // Verify LoRA params are trainable again after unmerge - CHECK(lora_A->requires_grad()) << "lora_A should be trainable after unmerge"; - CHECK(lora_B->requires_grad()) << "lora_B should be trainable after unmerge"; - - // Print weight sum after unmerge - auto weight_unmerged = lora_linear->parameter(nn::Linear::kParamWeightName); - float weight_unmerged_sum = 
weight_unmerged->EigenMatrix().sum(); - std::cout << "\n--- After Unmerge ---" << std::endl; - std::cout << "Base weight sum after unmerge: " << weight_unmerged_sum << std::endl; - - // Verify: weight should be restored to original value - std::cout << "\nVerification: weight restored after unmerge? " << std::endl; - std::cout << " Original: " << weight_before_sum << std::endl; - std::cout << " Unmerged: " << weight_unmerged_sum << std::endl; - std::cout << " Diff: " << std::abs(weight_before_sum - weight_unmerged_sum) << std::endl; - CHECK(std::abs(weight_before_sum - weight_unmerged_sum) < 1e-4) << "Weight should be restored!"; - - // Get output after unmerge - auto output_unmerged = (*lora_linear)({input})[0]; - float output_unmerged_sum = output_unmerged->EigenMatrix().sum(); - std::cout << "Output sum after unmerge: " << output_unmerged_sum << std::endl; - - // Shape comparison: merge doesn't change shape, only weights - CHECK(output_before->Dims() == output_merged->Dims()) << "Shape should be identical after merge"; - CHECK(output_merged->Dims() == output_unmerged->Dims()) << "Shape should be identical after unmerge"; - - std::cout << "\nLoRALinear weight merging tests passed!" 
<< std::endl; -} - -// ============================================================================ -// Test 5: LoRA Utility Functions -// ============================================================================ -void test_lora_utils() { - std::cout << "\n=== Test 5: LoRA Utility Functions ===" << std::endl; - - LoRAConfig config; - config.rank = 4; - config.alpha = 8.0f; - - auto lora_linear = std::shared_ptr<LoRALinear>(new LoRALinear(32, 64, config, /*bias=*/true, nullptr)); - - // Test GetLoRAParameters - auto lora_params = GetLoRAParameters(lora_linear); - CHECK_EQ(lora_params.size(), 2) << "Should have 2 LoRA parameters"; - std::cout << "GetLoRAParameters returned " << lora_params.size() << " parameters" << std::endl; - - // Test CountTrainableParameters - int64_t trainable = CountTrainableParameters(lora_linear); - int64_t expected_trainable = config.rank * 32 + 64 * config.rank; // A: [4, 32], B: [64, 4] - CHECK_EQ(trainable, expected_trainable) << "Trainable parameter count mismatch"; - std::cout << "Trainable parameters: " << trainable << " (expected: " << expected_trainable << ")" << std::endl; - - // Test CountTotalParameters - int64_t total = CountTotalParameters(lora_linear); - int64_t expected_total = 64 * 32 + 64 + config.rank * 32 + 64 * config.rank; // weight + bias + A + B - CHECK_EQ(total, expected_total) << "Total parameter count mismatch"; - std::cout << "Total parameters: " << total << " (expected: " << expected_total << ")" << std::endl; - - // Test PrintLoRASummary - std::cout << "\nLoRA Summary:" << std::endl; - PrintLoRASummary(lora_linear); - - std::cout << "LoRA utility function tests passed!" 
<< std::endl; -} - -// ============================================================================ -// Test 6: LoRALinear from existing Linear -// ============================================================================ -void test_lora_from_linear() { - std::cout << "\n=== Test 6: LoRALinear from existing Linear ===" << std::endl; - - // Create a standard Linear layer - auto linear = std::make_shared<nn::Linear>(64, 128, /*bias=*/true); - - // Wrap it with LoRA - LoRAConfig config; - config.rank = 8; - config.alpha = 16.0f; - - auto lora_linear = std::make_shared<LoRALinear>(linear, config); - - // Check dimensions - CHECK_EQ(lora_linear->in_features(), 64); - CHECK_EQ(lora_linear->out_features(), 128); - CHECK_EQ(lora_linear->rank(), 8); - std::cout << "LoRALinear created from Linear: in=" << lora_linear->in_features() - << ", out=" << lora_linear->out_features() << ", rank=" << lora_linear->rank() << std::endl; - - // Test forward pass - auto input = std::make_shared<Tensor>(std::vector<int64_t>{2, 10, 64}, DataType::kFLOAT32); - auto output = (*lora_linear)({input})[0]; - - CHECK_EQ(output->Dims()[0], 2); - CHECK_EQ(output->Dims()[1], 10); - CHECK_EQ(output->Dims()[2], 128); - std::cout << "Forward pass successful, output shape: [" << output->Dims()[0] << ", " << output->Dims()[1] << ", " - << output->Dims()[2] << "]" << std::endl; - - std::cout << "LoRALinear from existing Linear tests passed!" 
<< std::endl; -} - -// ============================================================================ -// Test 7: LoRALinear from existing Linear (tests LoRA utilities) -// ============================================================================ -void test_lora_model_wrapper() { - std::cout << "\n=== Test 7: LoRALinear from existing Linear ===" << std::endl; - - // Create LoRA config - LoRAConfig lora_config; - lora_config.rank = 8; - lora_config.alpha = 16.0f; - - // Create base Linear module (simple test without InjectLoRALayers) - auto base_linear = std::make_shared<nn::Linear>(64, 128, /*bias=*/true); - - // Create a minimal wrapper test by manually testing what LoRAModel does - // Apply LoRA directly to the Linear layer - auto lora_linear = std::make_shared<LoRALinear>(base_linear, lora_config); - - // Replace the base_linear in its container - // Note: In a real use case, you would use InjectLoRALayers on a transformer model - - // Test GetLoRAParameters on the LoRA Linear - auto lora_params = GetLoRAParameters(lora_linear); - CHECK_GT(lora_params.size(), 0) << "Should have trainable parameters"; - std::cout << "LoRA parameters extracted: " << lora_params.size() << std::endl; - - // Test CountTrainableParameters - int64_t trainable = CountTrainableParameters(lora_linear); - CHECK_EQ(trainable, lora_config.rank * 64 + 128 * lora_config.rank); - std::cout << "Trainable parameters: " << trainable << std::endl; - - // Test PrintSummary - std::cout << "\nLoRA Summary for Linear wrapper:" << std::endl; - PrintLoRASummary(lora_linear); - - // Test Save/Load LoRA on the LoRA Linear - const std::string test_path = "/tmp/test_lora_linear.bin"; - SaveLoRAWeights(lora_linear, test_path); - std::cout << "SaveLoRAWeights completed" << std::endl; - - LoadLoRAWeights(lora_linear, test_path); - std::cout << "LoadLoRAWeights completed" << std::endl; - - // Test Merge/Unmerge on LoRA Linear - CHECK(!lora_linear->IsMerged()) << "Should not be merged initially"; - lora_linear->MergeWeights(); - 
CHECK(lora_linear->IsMerged()) << "Should be merged after MergeWeights()"; - std::cout << "MergeWeights completed" << std::endl; - - lora_linear->UnmergeWeights(); - CHECK(!lora_linear->IsMerged()) << "Should be unmerged after UnmergeWeights()"; - std::cout << "UnmergeWeights completed" << std::endl; - - std::cout << "LoRALinear utility tests passed!" << std::endl; -} - -// ============================================================================ -// Test 8: Save/Load LoRA Weights -// ============================================================================ -void test_lora_save_load_weights() { - std::cout << "\n=== Test 8: Save/Load LoRA Weights ===" << std::endl; - - // Create a LoRALinear - LoRAConfig config; - config.rank = 4; - config.alpha = 8.0f; - - int64_t in_features = 32; - int64_t out_features = 64; - - auto linear = std::make_shared<nn::Linear>(in_features, out_features, /*bias=*/true); - auto lora_linear = std::make_shared<LoRALinear>(linear, config); - - // Get references to lora_A and lora_B - auto lora_A = lora_linear->parameter(LoRALinear::kParamLoraAName); - auto lora_B = lora_linear->parameter(LoRALinear::kParamLoraBName); - - // Set specific values to lora_A and lora_B - // lora_A: [rank, in_features] = [4, 32] - // lora_B: [out_features, rank] = [64, 4] - lora_A->EigenMatrix().setZero(); - lora_B->EigenMatrix().setZero(); - - // Set lora_A to all 1s - for (int64_t i = 0; i < lora_A->Dims()[0]; ++i) { - for (int64_t j = 0; j < lora_A->Dims()[1]; ++j) { lora_A->EigenMatrix()(i, j) = 1.0f; } - } - - // Set lora_B to all 2s - for (int64_t i = 0; i < lora_B->Dims()[0]; ++i) { - for (int64_t j = 0; j < lora_B->Dims()[1]; ++j) { lora_B->EigenMatrix()(i, j) = 2.0f; } - } - - // Record original sums - float lora_A_sum_orig = lora_A->EigenMatrix().sum(); - float lora_B_sum_orig = lora_B->EigenMatrix().sum(); - // lora_A: all 1.0f, shape [rank, in_features] = [4, 32] - // lora_B: all 2.0f, shape [out_features, rank] = [64, 4] - float expected_lora_A_sum = config.rank * 
in_features * 1.0f; // 4 * 32 * 1 = 128 - float expected_lora_B_sum = out_features * config.rank * 2.0f; // 64 * 4 * 2 = 512 - std::cout << "Original lora_A sum: " << lora_A_sum_orig << " (expected: " << expected_lora_A_sum << ")" - << std::endl; - std::cout << "Original lora_B sum: " << lora_B_sum_orig << " (expected: " << expected_lora_B_sum << ")" - << std::endl; - - CHECK_EQ(lora_A_sum_orig, expected_lora_A_sum); - CHECK_EQ(lora_B_sum_orig, expected_lora_B_sum); - - // Save to file - const std::string test_path = "/tmp/test_lora_save_load.bin"; - SaveLoRAWeights(lora_linear, test_path); - std::cout << "Saved LoRA weights to: " << test_path << std::endl; - - // Modify weights to different values - lora_A->EigenMatrix().setConstant(9.0f); - lora_B->EigenMatrix().setConstant(9.0f); - - float lora_A_sum_modified = lora_A->EigenMatrix().sum(); - float lora_B_sum_modified = lora_B->EigenMatrix().sum(); - std::cout << "Modified lora_A sum: " << lora_A_sum_modified << std::endl; - std::cout << "Modified lora_B sum: " << lora_B_sum_modified << std::endl; - - CHECK_NE(lora_A_sum_modified, lora_A_sum_orig); - CHECK_NE(lora_B_sum_modified, lora_B_sum_orig); - - // Load from file - LoadLoRAWeights(lora_linear, test_path); - std::cout << "Loaded LoRA weights from: " << test_path << std::endl; - - // Verify weights are restored - float lora_A_sum_loaded = lora_A->EigenMatrix().sum(); - float lora_B_sum_loaded = lora_B->EigenMatrix().sum(); - std::cout << "Loaded lora_A sum: " << lora_A_sum_loaded << std::endl; - std::cout << "Loaded lora_B sum: " << lora_B_sum_loaded << std::endl; - - CHECK_EQ(lora_A_sum_loaded, lora_A_sum_orig) << "lora_A should be restored to original values"; - CHECK_EQ(lora_B_sum_loaded, lora_B_sum_orig) << "lora_B should be restored to original values"; - - // Also verify individual elements - for (int64_t i = 0; i < lora_A->Dims()[0]; ++i) { - for (int64_t j = 0; j < lora_A->Dims()[1]; ++j) { - CHECK_EQ(lora_A->EigenMatrix()(i, j), 1.0f) << "lora_A 
element mismatch at (" << i << "," << j << ")"; - } - } - - for (int64_t i = 0; i < lora_B->Dims()[0]; ++i) { - for (int64_t j = 0; j < lora_B->Dims()[1]; ++j) { - CHECK_EQ(lora_B->EigenMatrix()(i, j), 2.0f) << "lora_B element mismatch at (" << i << "," << j << ")"; - } - } - - std::cout << "All elements verified correctly!" << std::endl; - - // Cleanup - std::remove(test_path.c_str()); - std::cout << "Test 8: Save/Load LoRA Weights passed!" << std::endl; -} - -// ============================================================================ -// Test 8: ParseLoRATargetModules parsing -// ============================================================================ -void test_set_target_modules() { - std::cout << "\n=== Test 8: ParseLoRATargetModules Parsing ===" << std::endl; - - // Test single target - auto modules = ParseLoRATargetModules("c_attn"); - CHECK_EQ(modules.size(), 1); - CHECK(modules.count("c_attn")); - std::cout << "Single target: OK" << std::endl; - - // Test multiple targets - modules = ParseLoRATargetModules("c_attn,c_proj,c_fc"); - CHECK_EQ(modules.size(), 3); - CHECK(modules.count("c_attn")); - CHECK(modules.count("c_proj")); - CHECK(modules.count("c_fc")); - std::cout << "Multiple targets: OK" << std::endl; - - // Test with spaces - modules = ParseLoRATargetModules("c_attn, c_proj , c_fc"); - CHECK_EQ(modules.size(), 3); - std::cout << "Targets with spaces: OK" << std::endl; - - // Test empty/whitespace - modules = ParseLoRATargetModules("c_attn,,c_proj"); - CHECK_EQ(modules.size(), 2); - std::cout << "Empty entries ignored: OK" << std::endl; - - std::cout << "ParseLoRATargetModules tests passed!" 
<< std::endl; -} - -// ============================================================================ -// Test 9: ShouldApplyLoRA edge cases (attn.c_proj vs mlp.c_proj) -// ============================================================================ -void test_should_apply_lora_edge_cases() { - std::cout << "\n=== Test 9: ShouldApplyLoRA Edge Cases ===" << std::endl; - - // Test: Only attn.c_proj in target_modules - { - LoRAConfig config{8, 16.0f, 0.0f, ParseLoRATargetModules("c_attn,attn.c_proj")}; - - // Should match attention paths - CHECK(config.ShouldApplyLoRA("attn.c_proj")); - CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_proj")); - CHECK(config.ShouldApplyLoRA("transformer.h.1.attn.c_proj")); - - // Should NOT match mlp paths - CHECK(!config.ShouldApplyLoRA("mlp.c_proj")); - CHECK(!config.ShouldApplyLoRA("transformer.h.0.mlp.c_proj")); - std::cout << "attn.c_proj only: OK" << std::endl; - } - - // Test: Only mlp.c_proj in target_modules - { - LoRAConfig config{8, 16.0f, 0.0f, ParseLoRATargetModules("c_attn,mlp.c_proj")}; - - // Should NOT match attention paths - CHECK(!config.ShouldApplyLoRA("attn.c_proj")); - CHECK(!config.ShouldApplyLoRA("transformer.h.0.attn.c_proj")); - - // Should match mlp paths - CHECK(config.ShouldApplyLoRA("mlp.c_proj")); - CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_proj")); - std::cout << "mlp.c_proj only: OK" << std::endl; - } - - // Test: Generic c_proj in target_modules (matches both) - { - LoRAConfig config{8, 16.0f, 0.0f, ParseLoRATargetModules("c_attn,c_proj")}; - - // Should match both attention and mlp - CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_proj")); - CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_proj")); - std::cout << "Generic c_proj (matches both): OK" << std::endl; - } - - // Test: All targets - { - LoRAConfig config{8, 16.0f, 0.0f, ParseLoRATargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj")}; - - CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_attn")); - 
CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_proj")); - CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_fc")); - CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_fc2")); - CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_proj")); - std::cout << "All targets: OK" << std::endl; - } - - std::cout << "ShouldApplyLoRA edge cases tests passed!" << std::endl; -} - -// ============================================================================ -// Test 10: ReplaceModuleByPath -// ============================================================================ -void test_replace_module_by_path() { - std::cout << "\n=== Test 10: ReplaceModuleByPath ===" << std::endl; - - // Test ReplaceModuleByPath by wrapping a Linear with LoRA directly - // This tests the core functionality that ReplaceModuleByPath provides - - // Create base Linear - auto base_linear = std::make_shared<nn::Linear>(64, 128, /*bias=*/true); - - // Configure LoRA - LoRAConfig lora_config; - lora_config.rank = 4; - lora_config.alpha = 8.0f; - - // Wrap with LoRA - this is what ReplaceModuleByPath does internally - auto lora_linear = std::make_shared<LoRALinear>(base_linear, lora_config); - - // Verify LoRA was applied correctly - auto params = lora_linear->LoRAParameters(); - CHECK_EQ(params.size(), 2) << "LoRALinear should have 2 trainable parameters (lora_A and lora_B)"; - std::cout << "LoRALinear has " << params.size() << " trainable parameters" << std::endl; - - // Verify parameter shapes - auto lora_a = params[0]; - auto lora_b = params[1]; - CHECK_EQ(lora_a->Dims()[0], lora_config.rank); // rank x in_features - CHECK_EQ(lora_a->Dims()[1], 64); - CHECK_EQ(lora_b->Dims()[0], 128); // out_features x rank - CHECK_EQ(lora_b->Dims()[1], lora_config.rank); - std::cout << "LoRA parameter shapes: OK" << std::endl; - - // Verify base parameters are frozen (use named parameters instead of index) - auto weight = lora_linear->parameter(nn::Linear::kParamWeightName); - auto lora_a_param = 
lora_linear->parameter(LoRALinear::kParamLoraAName); - auto lora_b_param = lora_linear->parameter(LoRALinear::kParamLoraBName); - CHECK(weight != nullptr); - CHECK(lora_a_param != nullptr); - CHECK(lora_b_param != nullptr); - CHECK(!weight->requires_grad()); // weight is frozen - CHECK(lora_a_param->requires_grad()); // lora_A is trainable - CHECK(lora_b_param->requires_grad()); // lora_B is trainable - std::cout << "Base weight frozen, LoRA params trainable: OK" << std::endl; - - std::cout << "ReplaceModuleByPath tests passed!" << std::endl; -} - -// ============================================================================ -// Test 11: FreezeBaseModel / UnfreezeModel -// ============================================================================ -void test_freeze_unfreeze() { - std::cout << "\n=== Test 11: FreezeBaseModel / UnfreezeModel ===" << std::endl; - - // Test with LoRALinear directly - it has both base and LoRA params - LoRAConfig lora_config; - lora_config.rank = 4; - lora_config.alpha = 8.0f; - - auto linear = std::make_shared<nn::Linear>(64, 128, /*bias=*/true); - auto lora_linear = std::make_shared<LoRALinear>(linear, lora_config); - - // Get all parameters from LoRALinear (includes base + LoRA) - auto all_params = lora_linear->Parameters(); - - // Initially only LoRA params should be trainable (base weights are frozen by constructor) - int64_t total_params = 0; - for (const auto &p : all_params) { - if (p->requires_grad()) { - total_params += p->NumElements(); - } - } - // Expected: only LoRA params (lora_A + lora_B) = 4*64 + 128*4 = 256 + 512 = 768 - // Note: LoRALinear freezes base weights in constructor by design - int64_t expected_total = lora_config.rank * 64 + 128 * lora_config.rank; - CHECK_EQ(total_params, expected_total); - std::cout << "Initial trainable params: " << total_params << " (expected: " << expected_total << ")" << std::endl; - - // FreezeBaseModel on LoRALinear - FreezeBaseModel(lora_linear); - - // After freeze, only LoRA params should be trainable 
- int64_t after_freeze = 0; - for (const auto &p : all_params) { - if (p->requires_grad()) { - after_freeze += p->NumElements(); - } - } - // LoRA params: A (rank x in) + B (out x rank) = 4*64 + 128*4 = 256 + 512 = 768 - int64_t expected_lora = lora_config.rank * 64 + 128 * lora_config.rank; - CHECK_EQ(after_freeze, expected_lora); - std::cout << "After freeze trainable: " << after_freeze << " (expected: " << expected_lora << ")" << std::endl; - - // Unfreeze all - UnfreezeModel(lora_linear); - int64_t after_unfreeze = 0; - for (const auto &p : all_params) { - if (p->requires_grad()) { - after_unfreeze += p->NumElements(); - } - } - // Should be back to all params trainable (base + LoRA) - int64_t expected_after_unfreeze = 64 * 128 + 128 + lora_config.rank * 64 + 128 * lora_config.rank; - CHECK_EQ(after_unfreeze, expected_after_unfreeze); - std::cout << "After unfreeze trainable: " << after_unfreeze << std::endl; - - std::cout << "FreezeBaseModel / UnfreezeModel tests passed!" << std::endl; -} - -// ============================================================================ -// Test 12: LoRAStateDict -// ============================================================================ -void test_lora_state_dict() { - std::cout << "\n=== Test 12: LoRAStateDict ===" << std::endl; - - // Test with a single LoRALinear - LoRAConfig lora_config; - lora_config.rank = 4; - lora_config.alpha = 8.0f; - - auto linear = std::make_shared<nn::Linear>(64, 128, /*bias=*/true); - auto lora_linear = std::make_shared<LoRALinear>(linear, lora_config); - - // Get state dict - it contains all parameters with their names - auto state_dict = lora_linear->StateDict(); - - // Check that we have all expected parameters - CHECK(state_dict.count("weight")) << "Should have weight parameter"; - CHECK(state_dict.count("bias")) << "Should have bias parameter"; - CHECK(state_dict.count("lora_A")) << "Should have lora_A parameter"; - CHECK(state_dict.count("lora_B")) << "Should have lora_B parameter"; - std::cout << "State 
dict contains: weight, bias, lora_A, lora_B" << std::endl; - - // Verify LoRA parameters exist and are trainable - CHECK(state_dict.at("lora_A")->requires_grad()) << "lora_A should be trainable"; - CHECK(state_dict.at("lora_B")->requires_grad()) << "lora_B should be trainable"; - CHECK(!state_dict.at("weight")->requires_grad()) << "weight should be frozen"; - std::cout << "LoRA parameters are trainable, base weight is frozen: OK" << std::endl; - - // Verify shapes - CHECK_EQ(state_dict.at("lora_A")->Dims()[0], lora_config.rank); - CHECK_EQ(state_dict.at("lora_A")->Dims()[1], 64); - CHECK_EQ(state_dict.at("lora_B")->Dims()[0], 128); - CHECK_EQ(state_dict.at("lora_B")->Dims()[1], lora_config.rank); - std::cout << "LoRA parameter shapes: OK" << std::endl; - - std::cout << "LoRAStateDict tests passed!" << std::endl; -} - -// ============================================================================ -// Test 13: GetLoRAModel simplified API -// ============================================================================ -void test_get_lora_model() { - std::cout << "\n=== Test 13: GetLoRAModel Simplified API ===" << std::endl; - - // Test GetLoRAModel with a simple Linear layer - // We'll wrap it with LoRA directly and verify the wrapper works - - // Create base Linear - auto base_linear = std::make_shared<nn::Linear>(64, 128, /*bias=*/true); - - // Configure LoRA - LoRAConfig config{4, 8.0f, 0.0f, ParseLoRATargetModules("Linear")}; - - // Use GetLoRAModel with the linear as the "model" - // Note: GetLoRAModel returns the modified model (in-place injection) - auto model = GetLoRAModel(base_linear, config); - - CHECK(model != nullptr); - std::cout << "GetLoRAModel returned valid pointer" << std::endl; - - // Test that LoRA was applied - check trainable parameters - auto lora_params = GetLoRAParameters(model); - // GetLoRAParameters returns vector<std::shared_ptr<Tensor>>, size() is the count of tensors - // LoRALinear has 2 trainable tensors: lora_A (rank x in) and lora_B (out x rank) - 
CHECK_EQ(lora_params.size(), 2); - std::cout << "Trainable parameter tensors: " << lora_params.size() << " (expected: 2)" << std::endl; - - // Also verify total element count - int64_t total_elements = 0; - for (const auto &t : lora_params) { total_elements += t->NumElements(); } - int64_t expected_elements = config.rank * 64 + 128 * config.rank; // 768 - CHECK_EQ(total_elements, expected_elements); - std::cout << "Total trainable elements: " << total_elements << " (expected: " << expected_elements << ")" - << std::endl; - - // Test PrintSummary - std::cout << "\nLoRA Model Summary:" << std::endl; - PrintLoRASummary(model); - - // Test Merge/Unmerge using utility functions - MergeLoRAWeights(model); - // Verify LoRA params frozen after merge - auto *lora_mod = dynamic_cast<LoRALinear *>(model.get()); - CHECK(lora_mod != nullptr); - CHECK(!lora_mod->LoRAParameters()[0]->requires_grad()) << "lora_A should be frozen after merge"; - CHECK(!lora_mod->LoRAParameters()[1]->requires_grad()) << "lora_B should be frozen after merge"; - std::cout << "Merge: OK (LoRA params frozen)" << std::endl; - - UnmergeLoRAWeights(model); - CHECK(lora_mod->LoRAParameters()[0]->requires_grad()) << "lora_A should be trainable after unmerge"; - CHECK(lora_mod->LoRAParameters()[1]->requires_grad()) << "lora_B should be trainable after unmerge"; - std::cout << "Unmerge: OK (LoRA params trainable)" << std::endl; - - std::cout << "GetLoRAModel in-place injection tests passed!" 
<< std::endl; -} - -// ============================================================================ -// Test 14: MergeAndUnload -// ============================================================================ -void test_merge_and_unload() { - std::cout << "\n=== Test 14: MergeAndUnload ===" << std::endl; - - // Create base Linear and apply LoRA - auto base_linear = std::make_shared<nn::Linear>(64, 128, /*bias=*/true); - LoRAConfig config{4, 8.0f, 0.0f, ParseLoRATargetModules("Linear")}; - auto model = GetLoRAModel(base_linear, config); - - // Verify it's a LoRA module - CHECK(dynamic_cast<LoRALinear *>(model.get()) != nullptr) << "Should be LoRALinear"; - - // Create input and get output before merge_and_unload - auto input = std::make_shared<Tensor>(std::vector<int64_t>{2, 5, 64}, DataType::kFLOAT32); - input->EigenMatrix().setRandom(); - auto output_before = (*model)({input})[0]; - float output_before_sum = output_before->EigenMatrix().sum(); - std::cout << "Output sum before MergeAndUnload: " << output_before_sum << std::endl; - - // MergeAndUnload - auto unloaded_model = MergeAndUnload(model); - CHECK(unloaded_model != nullptr) << "MergeAndUnload should return valid model"; - - // Verify it's no longer a LoRA module - CHECK(dynamic_cast<LoRALinear *>(unloaded_model.get()) == nullptr) << "Should be plain Linear after MergeAndUnload"; - std::cout << "Model is no longer LoRALinear: OK" << std::endl; - - // Verify no LoRA parameters exist (check state dict) - auto state_dict = unloaded_model->StateDict(); - for (const auto &[name, param] : state_dict) { - CHECK(name.find("lora_A") == std::string::npos && name.find("lora_B") == std::string::npos) - << "Should not have LoRA parameters after MergeAndUnload, found: " << name; - } - std::cout << "No LoRA parameters in state dict: OK" << std::endl; - - // Verify forward output matches (merged output should equal unmerged LoRA output) - auto output_after = (*unloaded_model)({input})[0]; - float output_after_sum = output_after->EigenMatrix().sum(); - std::cout << "Output sum 
after MergeAndUnload: " << output_after_sum << std::endl; - std::cout << "Diff: " << std::abs(output_before_sum - output_after_sum) << std::endl; - CHECK(std::abs(output_before_sum - output_after_sum) < 1e-3) << "Output should match after MergeAndUnload"; - - // Verify all parameters have requires_grad = true (unfrozen) - for (const auto &param : unloaded_model->Parameters()) { - CHECK(param->requires_grad()) << "All parameters should be trainable after MergeAndUnload"; - } - std::cout << "All parameters trainable: OK" << std::endl; - - std::cout << "MergeAndUnload tests passed!" << std::endl; -} - -int main(int argc, char **argv) { - google::InitGoogleLogging(argv[0]); - FLAGS_logtostderr = 1; - - // Initialize parallel settings (required for some tensor operations) - // Parameters: nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, - // pipeline_parallel_size, virtual_pipeline_parallel_size - nn::parallel::global::InitAllEnv(1, 1, false, 1, 1); - - std::cout << "========================================" << std::endl; - std::cout << " LoRA Module Unit Tests " << std::endl; - std::cout << "========================================" << std::endl; - - test_lora_config(); - test_lora_linear_init(); - test_lora_linear_forward(); - test_lora_linear_merge(); - test_lora_utils(); - test_lora_from_linear(); - test_lora_model_wrapper(); - test_lora_save_load_weights(); - test_set_target_modules(); - test_should_apply_lora_edge_cases(); - test_replace_module_by_path(); - test_freeze_unfreeze(); - test_lora_state_dict(); - test_get_lora_model(); - test_merge_and_unload(); - - std::cout << "\n========================================" << std::endl; - std::cout << " All LoRA Tests Passed! 
" << std::endl; - std::cout << "========================================" << std::endl; - - return 0; -} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 00000000..39a44f27 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,24 @@ +# Tests CMakeLists.txt +# This file manages the test infrastructure for InfiniTrain + +# Include shared test macros (must be before any test subdirectory) +include(${CMAKE_CURRENT_SOURCE_DIR}/common/test_macros.cmake) + +# Common test utilities +add_subdirectory(common) + +# Tensor tests +add_subdirectory(tensor) + +# Optimizer tests +add_subdirectory(optimizer) + +# Autograd operator tests +add_subdirectory(autograd) + +# LoRA tests +add_subdirectory(lora) + +# Hook tests +add_subdirectory(hook) + diff --git a/tests/autograd/CMakeLists.txt b/tests/autograd/CMakeLists.txt new file mode 100644 index 00000000..d321f629 --- /dev/null +++ b/tests/autograd/CMakeLists.txt @@ -0,0 +1,11 @@ +# ============================================================================ +# Autograd tests +# ============================================================================ + +set(AUTOGRAD_TEST_DIR "${CMAKE_CURRENT_SOURCE_DIR}") + +file(GLOB AUTOGRAD_SOURCES ${AUTOGRAD_TEST_DIR}/test_autograd*.cc) + +infini_train_add_test_suite(test_autograd + SOURCES ${AUTOGRAD_SOURCES} +) diff --git a/tests/autograd/test_autograd.cc b/tests/autograd/test_autograd.cc new file mode 100644 index 00000000..6401cc93 --- /dev/null +++ b/tests/autograd/test_autograd.cc @@ -0,0 +1,376 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "infini_train/include/autograd/activations.h" +#include "infini_train/include/autograd/elementwise.h" +#include "infini_train/include/autograd/function.h" +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/autograd/matmul.h" +#include "infini_train/include/autograd/misc.h" +#include "infini_train/include/autograd/normalization.h" +#include "infini_train/include/autograd/outer.h" 
+#include "infini_train/include/autograd/reduction.h" +#include "infini_train/include/autograd/softmax.h" +#include "infini_train/include/autograd/transform.h" +#include "infini_train/include/tensor.h" +#include "test_utils.h" + +using namespace infini_train; + +// ============================================================================ +// Forward / Backward — CPU + CUDA +// ============================================================================ + +class AutogradForwardTest : public infini_train::test::AutogradTestBaseP {}; +class AutogradBackwardTest : public infini_train::test::AutogradTestBaseP {}; + +TEST_P(AutogradForwardTest, AddForward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({2, 3}, 2.0f); + auto result = std::make_shared<autograd::Add>()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{2, 3})); +} + +TEST_P(AutogradForwardTest, SubForward) { + auto a = createTensor({2, 3}, 5.0f); + auto b = createTensor({2, 3}, 3.0f); + auto result = std::make_shared<autograd::Sub>()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, MulForward) { + auto a = createTensor({2, 3}, 2.0f); + auto b = createTensor({2, 3}, 3.0f); + auto result = std::make_shared<autograd::Mul>()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, DivForward) { + auto a = createTensor({2, 3}, 6.0f); + auto b = createTensor({2, 3}, 2.0f); + auto result = std::make_shared<autograd::Div>()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, NegForward) { + auto a = createTensor({2, 3}, 5.0f); + auto result = std::make_shared<autograd::Neg>()->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, SinForward) { + auto a = createTensor({2, 3}, 0.0f); + auto result = std::make_shared<autograd::Sin>()->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, CosForward) { + auto a = createTensor({2, 3}, 0.0f); + auto result = std::make_shared<autograd::Cos>()->Apply({a}); + EXPECT_EQ(result.size(), 1); 
+} + +TEST_P(AutogradForwardTest, TanhForward) { + auto a = createTensor({2, 3}, 0.0f); + auto result = std::make_shared<autograd::Tanh>()->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, ExpForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared<autograd::Exp>()->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, LogForward) { + auto a = createTensor({2, 3}, 2.0f); + auto result = std::make_shared<autograd::Log>()->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, ReciprocalForward) { + auto a = createTensor({2, 3}, 2.0f); + auto result = std::make_shared<autograd::Reciprocal>()->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, PowForward) { + auto a = createTensor({2, 3}, 2.0f); + auto result = std::make_shared<autograd::Pow>(2.0f)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, RsqrtForward) { + auto a = createTensor({2, 3}, 4.0f); + auto result = std::make_shared<autograd::Rsqrt>()->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, SigmoidForward) { + auto a = createTensor({2, 3}, 0.0f); + auto result = std::make_shared<autograd::Sigmoid>()->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, MatmulForward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({3, 4}, 1.0f); + auto result = std::make_shared<autograd::Matmul>()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{2, 4})); +} + +TEST_P(AutogradForwardTest, SumForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared<autograd::Sum>(1, false)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, MeanForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared<autograd::Mean>(1, false)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, MaxForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared<autograd::Max>(1, false)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + 
+TEST_P(AutogradForwardTest, MinForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared<autograd::Min>(1, false)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, SoftmaxForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared<autograd::Softmax>(1)->Apply({a}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{2, 3})); +} + +TEST_P(AutogradForwardTest, LayerNormForward) { + auto a = createTensor({2, 3, 4}, 1.0f); + auto weight = createTensor({4}, 1.0f); + auto bias = createTensor({4}, 0.0f); + auto result = std::make_shared<autograd::LayerNorm>(1e-5f)->Apply({a, weight, bias}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, LinearForward) { + auto input = createTensor({2, 3}, 1.0f); + auto weight = createTensor({4, 3}, 1.0f); + auto bias = createTensor({4}, 0.0f); + auto result = std::make_shared<autograd::Linear>()->Apply({input, weight, bias}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{2, 4})); +} + +TEST_P(AutogradForwardTest, TransposeForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared<autograd::Transpose>(0, 1)->Apply({a}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{3, 2})); +} + +TEST_P(AutogradForwardTest, SliceForward) { + auto a = createTensor({4, 4}, 1.0f); + auto result = std::make_shared<autograd::Slice>(std::vector<int64_t>{1, 1}, std::vector<int64_t>{3, 3}, + std::vector<int64_t>{1, 1}) + ->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, SplitForward) { + auto a = createTensor({4, 4}, 1.0f); + auto result = std::make_shared<autograd::Split>(2, 0)->Apply({a}); + EXPECT_EQ(result.size(), 2); +} + +TEST_P(AutogradForwardTest, ConcatForward) { + auto a = createTensor({2, 2}, 1.0f); + auto b = createTensor({2, 2}, 2.0f); + auto result = std::make_shared<autograd::Concat>(0)->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{4, 2})); +} + +TEST_P(AutogradForwardTest, StackForward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = 
createTensor({2, 3}, 2.0f); + auto result = std::make_shared(0)->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 2, 3})); +} + +TEST_P(AutogradForwardTest, TrilForward) { + auto a = createTensor({3, 3}, 1.0f); + auto result = std::make_shared(0)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, TriuForward) { + auto a = createTensor({3, 3}, 1.0f); + auto result = std::make_shared(0)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, OuterForward) { + auto a = createTensor({3}, 1.0f); + auto b = createTensor({4}, 1.0f); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{3, 4})); +} + +TEST_P(AutogradForwardTest, AddScalarForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared(2.0f)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, MulScalarForward) { + auto a = createTensor({2, 3}, 2.0f); + auto result = std::make_shared(3.0f)->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, LtForward) { + auto a = createTensor({2, 3}, 5.0f); + auto b = createTensor({2, 3}, 3.0f); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, LeForward) { + auto a = createTensor({2, 3}, 3.0f); + auto b = createTensor({2, 3}, 3.0f); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, GtForward) { + auto a = createTensor({2, 3}, 5.0f); + auto b = createTensor({2, 3}, 3.0f); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, GeForward) { + auto a = createTensor({2, 3}, 3.0f); + auto b = createTensor({2, 3}, 3.0f); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, EqualsForward) { + auto a 
= createTensor({2, 3}, 3.0f); + auto b = createTensor({2, 3}, 3.0f); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, AndForward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({2, 3}, 1.0f); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, OrForward) { + auto a = createTensor({2, 3}, 0.0f); + auto b = createTensor({2, 3}, 1.0f); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_P(AutogradForwardTest, NoOpForward) { + auto a = createTensor({2, 3}, 1.0f); + auto result = std::make_shared(std::vector{2, 3})->Apply({a}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 3})); +} + +TEST_P(AutogradBackwardTest, AddBackward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({2, 3}, 2.0f); + auto add_fn = std::make_shared(); + auto result = add_fn->Apply({a, b}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = add_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} + +TEST_P(AutogradBackwardTest, MulBackward) { + auto a = createTensor({2, 3}, 2.0f); + auto b = createTensor({2, 3}, 3.0f); + auto mul_fn = std::make_shared(); + auto result = mul_fn->Apply({a, b}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = mul_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} + +INFINI_TRAIN_REGISTER_TEST(AutogradForwardTest); + +INFINI_TRAIN_REGISTER_TEST(AutogradBackwardTest); + +// ============================================================================ +// Distributed — requires NCCL + >=2 GPUs +// ============================================================================ + +class AutogradDistributedTest : public infini_train::test::DistributedInfiniTrainTestP {}; + +TEST_P(AutogradDistributedTest, AllReduce) { + auto a = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32, GetDevice()); + 
a->set_requires_grad(true); + infini_train::test::FillConstantTensor(a, 1.0f); + EXPECT_TRUE(a->GetDevice().IsCUDA()); + EXPECT_TRUE(a->requires_grad()); +} + +TEST_P(AutogradDistributedTest, AllGather) { + auto a = std::make_shared(std::vector{4, 4}, DataType::kFLOAT32, GetDevice()); + a->set_requires_grad(true); + infini_train::test::FillConstantTensor(a, 1.0f); + EXPECT_TRUE(a->GetDevice().IsCUDA()); + EXPECT_EQ(a->Dims(), (std::vector{4, 4})); +} + +TEST_P(AutogradDistributedTest, ReduceScatter) { + auto a = std::make_shared(std::vector{2, 8}, DataType::kFLOAT32, GetDevice()); + a->set_requires_grad(true); + infini_train::test::FillConstantTensor(a, 1.0f); + EXPECT_TRUE(a->GetDevice().IsCUDA()); + EXPECT_EQ(a->Dims(), (std::vector{2, 8})); +} + +TEST_P(AutogradDistributedTest, DistributedMatmul) { + auto a = std::make_shared(std::vector{2, 4}, DataType::kFLOAT32, GetDevice()); + a->set_requires_grad(true); + auto b = std::make_shared(std::vector{4, 2}, DataType::kFLOAT32, GetDevice()); + b->set_requires_grad(true); + auto result = std::make_shared()->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_TRUE(result[0]->GetDevice().IsCUDA()); +} + +TEST_P(AutogradDistributedTest, DistributedLinear) { + auto input = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32, GetDevice()); + input->set_requires_grad(true); + auto weight = std::make_shared(std::vector{4, 3}, DataType::kFLOAT32, GetDevice()); + weight->set_requires_grad(true); + auto bias = std::make_shared(std::vector{4}, DataType::kFLOAT32, GetDevice()); + bias->set_requires_grad(true); + auto result = std::make_shared()->Apply({input, weight, bias}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 4})); + EXPECT_TRUE(result[0]->GetDevice().IsCUDA()); +} + +INFINI_TRAIN_REGISTER_TEST_DISTRIBUTED(AutogradDistributedTest); diff --git a/tests/autograd/test_autograd_elementwise_backward.cc b/tests/autograd/test_autograd_elementwise_backward.cc new file mode 100644 index 
00000000..502a20e5 --- /dev/null +++ b/tests/autograd/test_autograd_elementwise_backward.cc @@ -0,0 +1,134 @@ +#include + +#include +#include + +#include "infini_train/include/tensor.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/autograd/elementwise.h" +#include "test_utils.h" + +using namespace infini_train; + +class AutogradElementwiseBackwardTest : public infini_train::test::AutogradTestBase {}; + +TEST_F(AutogradElementwiseBackwardTest, AddBackward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({2, 3}, 2.0f); + auto add_fn = std::make_shared(); + auto result = add_fn->Apply({a, b}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = add_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} + +TEST_F(AutogradElementwiseBackwardTest, SubBackward) { + auto a = createTensor({2, 3}, 5.0f); + auto b = createTensor({2, 3}, 3.0f); + auto sub_fn = std::make_shared(); + auto result = sub_fn->Apply({a, b}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = sub_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} + +TEST_F(AutogradElementwiseBackwardTest, MulBackward) { + auto a = createTensor({2, 3}, 2.0f); + auto b = createTensor({2, 3}, 3.0f); + auto mul_fn = std::make_shared(); + auto result = mul_fn->Apply({a, b}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = mul_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} + +TEST_F(AutogradElementwiseBackwardTest, DivBackward) { + auto a = createTensor({2, 3}, 6.0f); + auto b = createTensor({2, 3}, 2.0f); + auto div_fn = std::make_shared(); + auto result = div_fn->Apply({a, b}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = div_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} + +TEST_F(AutogradElementwiseBackwardTest, NegBackward) { + auto a = createTensor({2, 3}, 5.0f); + auto neg_fn = std::make_shared(); + auto result = neg_fn->Apply({a}); + auto grad = 
createTensor({2, 3}, 1.0f); + auto grad_inputs = neg_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} + +TEST_F(AutogradElementwiseBackwardTest, SinBackward) { + auto a = createTensor({2, 3}, 0.0f); + auto sin_fn = std::make_shared(); + auto result = sin_fn->Apply({a}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = sin_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} + +TEST_F(AutogradElementwiseBackwardTest, CosBackward) { + auto a = createTensor({2, 3}, 0.0f); + auto cos_fn = std::make_shared(); + auto result = cos_fn->Apply({a}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = cos_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} + +TEST_F(AutogradElementwiseBackwardTest, TanhBackward) { + auto a = createTensor({2, 3}, 0.0f); + auto tanh_fn = std::make_shared(); + auto result = tanh_fn->Apply({a}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = tanh_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} + +TEST_F(AutogradElementwiseBackwardTest, ExpBackward) { + auto a = createTensor({2, 3}, 1.0f); + auto exp_fn = std::make_shared(); + auto result = exp_fn->Apply({a}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = exp_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} + +TEST_F(AutogradElementwiseBackwardTest, LogBackward) { + auto a = createTensor({2, 3}, 2.0f); + auto log_fn = std::make_shared(); + auto result = log_fn->Apply({a}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = log_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} + +TEST_F(AutogradElementwiseBackwardTest, ReciprocalBackward) { + auto a = createTensor({2, 3}, 2.0f); + auto reciprocal_fn = std::make_shared(); + auto result = reciprocal_fn->Apply({a}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = reciprocal_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} + +TEST_F(AutogradElementwiseBackwardTest, PowBackward) { + auto a = 
createTensor({2, 3}, 2.0f); + auto pow_fn = std::make_shared(2.0f); + auto result = pow_fn->Apply({a}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = pow_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} + +TEST_F(AutogradElementwiseBackwardTest, RsqrtBackward) { + auto a = createTensor({2, 3}, 4.0f); + auto rsqrt_fn = std::make_shared(); + auto result = rsqrt_fn->Apply({a}); + auto grad = createTensor({2, 3}, 1.0f); + auto grad_inputs = rsqrt_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 1); +} diff --git a/tests/autograd/test_autograd_elementwise_forward.cc b/tests/autograd/test_autograd_elementwise_forward.cc new file mode 100644 index 00000000..63b386b1 --- /dev/null +++ b/tests/autograd/test_autograd_elementwise_forward.cc @@ -0,0 +1,187 @@ +#include + +#include +#include + +#include "infini_train/include/tensor.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/autograd/elementwise.h" +#include "infini_train/include/autograd/activations.h" +#include "test_utils.h" + +using namespace infini_train; + +class AutogradElementwiseForwardTest : public infini_train::test::AutogradTestBase {}; + +TEST_F(AutogradElementwiseForwardTest, AddForward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({2, 3}, 2.0f); + auto add_fn = std::make_shared(); + auto result = add_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 3})); +} + +TEST_F(AutogradElementwiseForwardTest, SubForward) { + auto a = createTensor({2, 3}, 5.0f); + auto b = createTensor({2, 3}, 3.0f); + auto sub_fn = std::make_shared(); + auto result = sub_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, MulForward) { + auto a = createTensor({2, 3}, 2.0f); + auto b = createTensor({2, 3}, 3.0f); + auto mul_fn = std::make_shared(); + auto result = mul_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + 
+TEST_F(AutogradElementwiseForwardTest, DivForward) { + auto a = createTensor({2, 3}, 6.0f); + auto b = createTensor({2, 3}, 2.0f); + auto div_fn = std::make_shared(); + auto result = div_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, NegForward) { + auto a = createTensor({2, 3}, 5.0f); + auto neg_fn = std::make_shared(); + auto result = neg_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, SinForward) { + auto a = createTensor({2, 3}, 0.0f); + auto sin_fn = std::make_shared(); + auto result = sin_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, CosForward) { + auto a = createTensor({2, 3}, 0.0f); + auto cos_fn = std::make_shared(); + auto result = cos_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, TanhForward) { + auto a = createTensor({2, 3}, 0.0f); + auto tanh_fn = std::make_shared(); + auto result = tanh_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, ExpForward) { + auto a = createTensor({2, 3}, 1.0f); + auto exp_fn = std::make_shared(); + auto result = exp_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, LogForward) { + auto a = createTensor({2, 3}, 2.0f); + auto log_fn = std::make_shared(); + auto result = log_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, ReciprocalForward) { + auto a = createTensor({2, 3}, 2.0f); + auto reciprocal_fn = std::make_shared(); + auto result = reciprocal_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, PowForward) { + auto a = createTensor({2, 3}, 2.0f); + auto pow_fn = std::make_shared(2.0f); + auto result = pow_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, RsqrtForward) { + auto a = createTensor({2, 3}, 4.0f); + auto rsqrt_fn = 
std::make_shared(); + auto result = rsqrt_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, SigmoidForward) { + auto a = createTensor({2, 3}, 0.0f); + auto sigmoid_fn = std::make_shared(); + auto result = sigmoid_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, AddScalarForward) { + auto a = createTensor({2, 3}, 1.0f); + auto add_scalar_fn = std::make_shared(2.0f); + auto result = add_scalar_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, MulScalarForward) { + auto a = createTensor({2, 3}, 2.0f); + auto mul_scalar_fn = std::make_shared(3.0f); + auto result = mul_scalar_fn->Apply({a}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, LtForward) { + auto a = createTensor({2, 3}, 5.0f); + auto b = createTensor({2, 3}, 3.0f); + auto lt_fn = std::make_shared(); + auto result = lt_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, LeForward) { + auto a = createTensor({2, 3}, 3.0f); + auto b = createTensor({2, 3}, 3.0f); + auto le_fn = std::make_shared(); + auto result = le_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, GtForward) { + auto a = createTensor({2, 3}, 5.0f); + auto b = createTensor({2, 3}, 3.0f); + auto gt_fn = std::make_shared(); + auto result = gt_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, GeForward) { + auto a = createTensor({2, 3}, 3.0f); + auto b = createTensor({2, 3}, 3.0f); + auto ge_fn = std::make_shared(); + auto result = ge_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, EqualsForward) { + auto a = createTensor({2, 3}, 3.0f); + auto b = createTensor({2, 3}, 3.0f); + auto eq_fn = std::make_shared(); + auto result = eq_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + 
+TEST_F(AutogradElementwiseForwardTest, AndForward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({2, 3}, 1.0f); + auto and_fn = std::make_shared(); + auto result = and_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradElementwiseForwardTest, OrForward) { + auto a = createTensor({2, 3}, 0.0f); + auto b = createTensor({2, 3}, 1.0f); + auto or_fn = std::make_shared(); + auto result = or_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); +} diff --git a/tests/autograd/test_autograd_linear_backward.cc b/tests/autograd/test_autograd_linear_backward.cc new file mode 100644 index 00000000..069affc7 --- /dev/null +++ b/tests/autograd/test_autograd_linear_backward.cc @@ -0,0 +1,33 @@ +#include + +#include + +#include "infini_train/include/tensor.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/autograd/linear.h" +#include "test_utils.h" + +using namespace infini_train; + +class AutogradLinearBackwardTest : public infini_train::test::AutogradTestBase {}; + +TEST_F(AutogradLinearBackwardTest, LinearBackward) { + auto input = createTensor({2, 3}, 1.0f); + auto weight = createTensor({4, 3}, 1.0f); + auto bias = createTensor({4}, 0.0f); + auto linear_fn = std::make_shared(); + auto result = linear_fn->Apply({input, weight, bias}); + auto grad = createTensor({2, 4}, 1.0f); + auto grad_inputs = linear_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 3); +} + +TEST_F(AutogradLinearBackwardTest, LinearBackwardNoBias) { + auto input = createTensor({2, 3}, 1.0f); + auto weight = createTensor({4, 3}, 1.0f); + auto linear_fn = std::make_shared(); + auto result = linear_fn->Apply({input, weight}); + auto grad = createTensor({2, 4}, 1.0f); + auto grad_inputs = linear_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} diff --git a/tests/autograd/test_autograd_linear_forward.cc b/tests/autograd/test_autograd_linear_forward.cc new file mode 100644 index 00000000..efd8d6eb --- /dev/null +++ 
b/tests/autograd/test_autograd_linear_forward.cc @@ -0,0 +1,41 @@ +#include + +#include + +#include "infini_train/include/tensor.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/autograd/linear.h" +#include "test_utils.h" + +using namespace infini_train; + +class AutogradLinearForwardTest : public infini_train::test::AutogradTestBase {}; + +TEST_F(AutogradLinearForwardTest, LinearForward) { + auto input = createTensor({2, 3}, 1.0f); + auto weight = createTensor({4, 3}, 1.0f); + auto bias = createTensor({4}, 0.0f); + auto linear_fn = std::make_shared(); + auto result = linear_fn->Apply({input, weight, bias}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 4})); +} + +TEST_F(AutogradLinearForwardTest, LinearNoBias) { + auto input = createTensor({2, 3}, 1.0f); + auto weight = createTensor({4, 3}, 1.0f); + auto linear_fn = std::make_shared(); + auto result = linear_fn->Apply({input, weight}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 4})); +} + +TEST_F(AutogradLinearForwardTest, LinearBatch) { + auto input = createTensor({32, 128}, 1.0f); + auto weight = createTensor({64, 128}, 1.0f); + auto bias = createTensor({64}, 0.0f); + auto linear_fn = std::make_shared(); + auto result = linear_fn->Apply({input, weight, bias}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{32, 64})); +} diff --git a/tests/autograd/test_autograd_matmul_backward.cc b/tests/autograd/test_autograd_matmul_backward.cc new file mode 100644 index 00000000..e9962f5d --- /dev/null +++ b/tests/autograd/test_autograd_matmul_backward.cc @@ -0,0 +1,42 @@ +#include + +#include + +#include "infini_train/include/tensor.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/autograd/matmul.h" +#include "test_utils.h" + +using namespace infini_train; + +class AutogradMatmulBackwardTest : public infini_train::test::AutogradTestBase {}; + 
+TEST_F(AutogradMatmulBackwardTest, MatmulBackward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({3, 4}, 1.0f); + auto matmul_fn = std::make_shared(); + auto result = matmul_fn->Apply({a, b}); + auto grad = createTensor({2, 4}, 1.0f); + auto grad_inputs = matmul_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} + +TEST_F(AutogradMatmulBackwardTest, MatmulBackwardSquare) { + auto a = createTensor({3, 3}, 2.0f); + auto b = createTensor({3, 3}, 3.0f); + auto matmul_fn = std::make_shared(); + auto result = matmul_fn->Apply({a, b}); + auto grad = createTensor({3, 3}, 1.0f); + auto grad_inputs = matmul_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} + +TEST_F(AutogradMatmulBackwardTest, MatmulBackwardDifferentShapes) { + auto a = createTensor({3, 4}, 1.5f); + auto b = createTensor({4, 2}, 2.5f); + auto matmul_fn = std::make_shared(); + auto result = matmul_fn->Apply({a, b}); + auto grad = createTensor({3, 2}, 1.0f); + auto grad_inputs = matmul_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 2); +} diff --git a/tests/autograd/test_autograd_matmul_forward.cc b/tests/autograd/test_autograd_matmul_forward.cc new file mode 100644 index 00000000..87c93f08 --- /dev/null +++ b/tests/autograd/test_autograd_matmul_forward.cc @@ -0,0 +1,48 @@ +#include + +#include + +#include "infini_train/include/tensor.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/autograd/matmul.h" +#include "test_utils.h" + +using namespace infini_train; + +class AutogradMatmulForwardTest : public infini_train::test::AutogradTestBase {}; + +TEST_F(AutogradMatmulForwardTest, MatmulForward) { + auto a = createTensor({2, 3}, 1.0f); + auto b = createTensor({3, 4}, 1.0f); + auto matmul_fn = std::make_shared(); + auto result = matmul_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 4})); +} + +TEST_F(AutogradMatmulForwardTest, MatmulDifferentShapes) { + auto a = createTensor({3, 
4}, 1.0f); + auto b = createTensor({4, 2}, 1.0f); + auto matmul_fn = std::make_shared(); + auto result = matmul_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{3, 2})); +} + +TEST_F(AutogradMatmulForwardTest, MatmulBatch) { + auto a = createTensor({2, 3, 4}, 1.0f); + auto b = createTensor({2, 4, 5}, 1.0f); + auto matmul_fn = std::make_shared(); + auto result = matmul_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 3, 5})); +} + +TEST_F(AutogradMatmulForwardTest, MatmulSquare) { + auto a = createTensor({3, 3}, 1.0f); + auto b = createTensor({3, 3}, 1.0f); + auto matmul_fn = std::make_shared(); + auto result = matmul_fn->Apply({a, b}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{3, 3})); +} diff --git a/tests/autograd/test_autograd_normalization_backward.cc b/tests/autograd/test_autograd_normalization_backward.cc new file mode 100644 index 00000000..6f97349e --- /dev/null +++ b/tests/autograd/test_autograd_normalization_backward.cc @@ -0,0 +1,34 @@ +#include + +#include + +#include "infini_train/include/tensor.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/autograd/normalization.h" +#include "test_utils.h" + +using namespace infini_train; + +class AutogradNormalizationBackwardTest : public infini_train::test::AutogradTestBase {}; + +TEST_F(AutogradNormalizationBackwardTest, LayerNormBackward) { + auto a = createTensor({2, 3, 4}, 1.0f); + auto weight = createTensor({4}, 1.0f); + auto bias = createTensor({4}, 0.0f); + auto layernorm_fn = std::make_shared(1e-5f); + auto result = layernorm_fn->Apply({a, weight, bias}); + auto grad = createTensor({2, 3, 4}, 1.0f); + auto grad_inputs = layernorm_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 3); +} + +TEST_F(AutogradNormalizationBackwardTest, LayerNormBackwardZeroBias) { + auto a = createTensor({2, 3, 4}, 1.0f); + auto weight = createTensor({4}, 
1.0f); + auto bias = createTensor({4}, 0.0f); + auto layernorm_fn = std::make_shared(1e-5f); + auto result = layernorm_fn->Apply({a, weight, bias}); + auto grad = createTensor({2, 3, 4}, 1.0f); + auto grad_inputs = layernorm_fn->Backward({grad}); + EXPECT_EQ(grad_inputs.size(), 3); +} diff --git a/tests/autograd/test_autograd_normalization_forward.cc b/tests/autograd/test_autograd_normalization_forward.cc new file mode 100644 index 00000000..d58fd749 --- /dev/null +++ b/tests/autograd/test_autograd_normalization_forward.cc @@ -0,0 +1,40 @@ +#include + +#include + +#include "infini_train/include/tensor.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/autograd/normalization.h" +#include "test_utils.h" + +using namespace infini_train; + +class AutogradNormalizationForwardTest : public infini_train::test::AutogradTestBase {}; + +TEST_F(AutogradNormalizationForwardTest, LayerNormForward) { + auto a = createTensor({2, 3, 4}, 1.0f); + auto weight = createTensor({4}, 1.0f); + auto bias = createTensor({4}, 0.0f); + auto layernorm_fn = std::make_shared(1e-5f); + auto result = layernorm_fn->Apply({a, weight, bias}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradNormalizationForwardTest, LayerNormZeroBias) { + auto a = createTensor({2, 3, 4}, 1.0f); + auto weight = createTensor({4}, 1.0f); + auto bias = createTensor({4}, 0.0f); + auto layernorm_fn = std::make_shared(1e-5f); + auto result = layernorm_fn->Apply({a, weight, bias}); + EXPECT_EQ(result.size(), 1); +} + +TEST_F(AutogradNormalizationForwardTest, LayerNormThreeDim) { + auto a = createTensor({2, 1, 4}, 1.0f); + auto weight = createTensor({4}, 1.0f); + auto bias = createTensor({4}, 0.0f); + auto layernorm_fn = std::make_shared(1e-5f); + auto result = layernorm_fn->Apply({a, weight, bias}); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0]->Dims(), (std::vector{2, 1, 4})); +} diff --git a/tests/autograd/test_autograd_reduction_backward.cc 
b/tests/autograd/test_autograd_reduction_backward.cc
new file mode 100644
index 00000000..d212a065
--- /dev/null
+++ b/tests/autograd/test_autograd_reduction_backward.cc
@@ -0,0 +1,66 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "infini_train/include/tensor.h"
+#include "infini_train/include/nn/parallel/global.h"
+#include "infini_train/include/autograd/reduction.h"
+#include "test_utils.h"
+
+using namespace infini_train;
+
+class AutogradReductionBackwardTest : public infini_train::test::AutogradTestBase {};
+
+TEST_F(AutogradReductionBackwardTest, SumBackward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto sum_fn = std::make_shared<autograd::Sum>(1, false);
+    auto result = sum_fn->Apply({a});
+    auto grad = createTensor({2}, 1.0f);
+    auto grad_inputs = sum_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
+
+TEST_F(AutogradReductionBackwardTest, MeanBackward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto mean_fn = std::make_shared<autograd::Mean>(1, false);
+    auto result = mean_fn->Apply({a});
+    auto grad = createTensor({2}, 1.0f);
+    auto grad_inputs = mean_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
+
+TEST_F(AutogradReductionBackwardTest, MaxBackward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto max_fn = std::make_shared<autograd::Max>(1, false);
+    auto result = max_fn->Apply({a});
+    auto grad = createTensor({2}, 1.0f);
+    auto grad_inputs = max_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
+
+TEST_F(AutogradReductionBackwardTest, MinBackward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto min_fn = std::make_shared<autograd::Min>(1, false);
+    auto result = min_fn->Apply({a});
+    auto grad = createTensor({2}, 1.0f);
+    auto grad_inputs = min_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
+
+TEST_F(AutogradReductionBackwardTest, SumBackwardKeepDim) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto sum_fn = std::make_shared<autograd::Sum>(1, true);
+    auto result = sum_fn->Apply({a});
+    auto grad = createTensor({2, 1}, 1.0f);
+    auto grad_inputs = sum_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
+
+TEST_F(AutogradReductionBackwardTest, MeanBackwardKeepDim) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto mean_fn = std::make_shared<autograd::Mean>(1, true);
+    auto result = mean_fn->Apply({a});
+    auto grad = createTensor({2, 1}, 1.0f);
+    auto grad_inputs = mean_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
diff --git a/tests/autograd/test_autograd_reduction_forward.cc b/tests/autograd/test_autograd_reduction_forward.cc
new file mode 100644
index 00000000..b4f8edb7
--- /dev/null
+++ b/tests/autograd/test_autograd_reduction_forward.cc
@@ -0,0 +1,54 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "infini_train/include/tensor.h"
+#include "infini_train/include/nn/parallel/global.h"
+#include "infini_train/include/autograd/reduction.h"
+#include "test_utils.h"
+
+using namespace infini_train;
+
+class AutogradReductionForwardTest : public infini_train::test::AutogradTestBase {};
+
+TEST_F(AutogradReductionForwardTest, SumForward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto sum_fn = std::make_shared<autograd::Sum>(1, false);
+    auto result = sum_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
+
+TEST_F(AutogradReductionForwardTest, MeanForward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto mean_fn = std::make_shared<autograd::Mean>(1, false);
+    auto result = mean_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
+
+TEST_F(AutogradReductionForwardTest, MaxForward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto max_fn = std::make_shared<autograd::Max>(1, false);
+    auto result = max_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
+
+TEST_F(AutogradReductionForwardTest, MinForward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto min_fn = std::make_shared<autograd::Min>(1, false);
+    auto result = min_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
+
+TEST_F(AutogradReductionForwardTest, SumKeepDim) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto sum_fn = std::make_shared<autograd::Sum>(1, true);
+    auto result = sum_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
+
+TEST_F(AutogradReductionForwardTest, MeanKeepDim) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto mean_fn = std::make_shared<autograd::Mean>(1, true);
+    auto result = mean_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
diff --git a/tests/autograd/test_autograd_softmax_backward.cc b/tests/autograd/test_autograd_softmax_backward.cc
new file mode 100644
index 00000000..6d3f02a4
--- /dev/null
+++ b/tests/autograd/test_autograd_softmax_backward.cc
@@ -0,0 +1,30 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "infini_train/include/tensor.h"
+#include "infini_train/include/nn/parallel/global.h"
+#include "infini_train/include/autograd/softmax.h"
+#include "test_utils.h"
+
+using namespace infini_train;
+
+class AutogradSoftmaxBackwardTest : public infini_train::test::AutogradTestBase {};
+
+TEST_F(AutogradSoftmaxBackwardTest, SoftmaxBackward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto softmax_fn = std::make_shared<autograd::Softmax>(1);
+    auto result = softmax_fn->Apply({a});
+    auto grad = createTensor({2, 3}, 1.0f);
+    auto grad_inputs = softmax_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
+
+TEST_F(AutogradSoftmaxBackwardTest, SoftmaxBackwardDim0) {
+    auto a = createTensor({4, 3}, 1.0f);
+    auto softmax_fn = std::make_shared<autograd::Softmax>(0);
+    auto result = softmax_fn->Apply({a});
+    auto grad = createTensor({4, 3}, 1.0f);
+    auto grad_inputs = softmax_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
diff --git a/tests/autograd/test_autograd_softmax_forward.cc b/tests/autograd/test_autograd_softmax_forward.cc
new file mode 100644
index 00000000..c3d196f1
--- /dev/null
+++ b/tests/autograd/test_autograd_softmax_forward.cc
@@ -0,0 +1,36 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "infini_train/include/tensor.h"
+#include "infini_train/include/nn/parallel/global.h"
+#include "infini_train/include/autograd/softmax.h"
+#include "test_utils.h"
+
+using namespace infini_train;
+
+class AutogradSoftmaxForwardTest : public infini_train::test::AutogradTestBase {};
+
+TEST_F(AutogradSoftmaxForwardTest, SoftmaxForward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto softmax_fn = std::make_shared<autograd::Softmax>(1);
+    auto result = softmax_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+    EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{2, 3}));
+}
+
+TEST_F(AutogradSoftmaxForwardTest, SoftmaxDim0) {
+    auto a = createTensor({4, 3}, 1.0f);
+    auto softmax_fn = std::make_shared<autograd::Softmax>(0);
+    auto result = softmax_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+    EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{4, 3}));
+}
+
+TEST_F(AutogradSoftmaxForwardTest, SoftmaxLastDim) {
+    auto a = createTensor({2, 3, 4}, 1.0f);
+    auto softmax_fn = std::make_shared<autograd::Softmax>(2);
+    auto result = softmax_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+    EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{2, 3, 4}));
+}
diff --git a/tests/autograd/test_autograd_transform_backward.cc b/tests/autograd/test_autograd_transform_backward.cc
new file mode 100644
index 00000000..1613f1a2
--- /dev/null
+++ b/tests/autograd/test_autograd_transform_backward.cc
@@ -0,0 +1,21 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "infini_train/include/tensor.h"
+#include "infini_train/include/nn/parallel/global.h"
+#include "infini_train/include/autograd/transform.h"
+#include "test_utils.h"
+
+using namespace infini_train;
+
+class AutogradTransformBackwardTest : public infini_train::test::AutogradTestBase {};
+
+TEST_F(AutogradTransformBackwardTest, TransposeBackward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto transpose_fn = std::make_shared<autograd::Transpose>(0, 1);
+    auto result = transpose_fn->Apply({a});
+    auto grad = createTensor({3, 2}, 1.0f);
+    auto grad_inputs = transpose_fn->Backward({grad});
+    EXPECT_EQ(grad_inputs.size(), 1);
+}
diff --git a/tests/autograd/test_autograd_transform_forward.cc b/tests/autograd/test_autograd_transform_forward.cc
new file mode 100644
index 00000000..67b20adb
--- /dev/null
+++ b/tests/autograd/test_autograd_transform_forward.cc
@@ -0,0 +1,70 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "infini_train/include/tensor.h"
+#include "infini_train/include/nn/parallel/global.h"
+#include "infini_train/include/autograd/transform.h"
+#include "infini_train/include/autograd/misc.h"
+#include "test_utils.h"
+
+using namespace infini_train;
+
+class AutogradTransformForwardTest : public infini_train::test::AutogradTestBase {};
+
+TEST_F(AutogradTransformForwardTest, TransposeForward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto transpose_fn = std::make_shared<autograd::Transpose>(0, 1);
+    auto result = transpose_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+    EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{3, 2}));
+}
+
+TEST_F(AutogradTransformForwardTest, SliceForward) {
+    auto a = createTensor({4, 4}, 1.0f);
+    auto slice_fn = std::make_shared<autograd::Slice>(
+        std::vector<int64_t>{1, 1},
+        std::vector<int64_t>{3, 3},
+        std::vector<int64_t>{1, 1});
+    auto result = slice_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
+
+TEST_F(AutogradTransformForwardTest, SplitForward) {
+    auto a = createTensor({4, 4}, 1.0f);
+    auto split_fn = std::make_shared<autograd::Split>(2, 0);
+    auto result = split_fn->Apply({a});
+    EXPECT_EQ(result.size(), 2);
+}
+
+TEST_F(AutogradTransformForwardTest, ConcatForward) {
+    auto a = createTensor({2, 2}, 1.0f);
+    auto b = createTensor({2, 2}, 2.0f);
+    auto concat_fn = std::make_shared<autograd::Concat>(0);
+    auto result = concat_fn->Apply({a, b});
+    EXPECT_EQ(result.size(), 1);
+    EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{4, 2}));
+}
+
+TEST_F(AutogradTransformForwardTest, StackForward) {
+    auto a = createTensor({2, 3}, 1.0f);
+    auto b = createTensor({2, 3}, 2.0f);
+    auto stack_fn = std::make_shared<autograd::Stack>(0);
+    auto result = stack_fn->Apply({a, b});
+    EXPECT_EQ(result.size(), 1);
+    EXPECT_EQ(result[0]->Dims(), (std::vector<int64_t>{2, 2, 3}));
+}
+
+TEST_F(AutogradTransformForwardTest, TrilForward) {
+    auto a = createTensor({3, 3}, 1.0f);
+    auto tril_fn = std::make_shared<autograd::Tril>(0);
+    auto result = tril_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
+
+TEST_F(AutogradTransformForwardTest, TriuForward) {
+    auto a = createTensor({3, 3}, 1.0f);
+    auto triu_fn = std::make_shared<autograd::Triu>(0);
+    auto result = triu_fn->Apply({a});
+    EXPECT_EQ(result.size(), 1);
+}
diff --git a/tests/common/CMakeLists.txt b/tests/common/CMakeLists.txt
new file mode 100644
index 00000000..3960d474
--- /dev/null
+++ b/tests/common/CMakeLists.txt
@@ -0,0 +1,4 @@
+# Common test utilities
+
+add_library(test_utils INTERFACE)
+target_include_directories(test_utils INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/tests/common/test_macros.cmake b/tests/common/test_macros.cmake
new file mode 100644
index 00000000..477a668d
--- /dev/null
+++ b/tests/common/test_macros.cmake
@@ -0,0 +1,141 @@
+# ============================================================================
+# InfiniTrain Test Macros
+# ============================================================================
+# Unified test configuration interface to reduce boilerplate.
+#
+# Usage:
+#   1. Include this file in tests/CMakeLists.txt
+#   2. Use infini_train_add_test macro to register tests
+#
+# Examples:
+#   infini_train_add_test(
+#       test_tensor_create
+#       SOURCES test_tensor_create.cc
+#       LABELS cpu cuda
+#   )
+# ============================================================================
+
+include_guard(GLOBAL)
+
+# Path to this file's directory (tests/common/)
+set(TEST_MACROS_DIR "${CMAKE_CURRENT_LIST_DIR}")
+
+# -----------------------------------------------------------------------------
+# Load GoogleTest module (provides gtest_discover_tests)
+# -----------------------------------------------------------------------------
+include(GoogleTest)
+
+# -----------------------------------------------------------------------------
+# infini_train_add_test - Test registration macro
+# -----------------------------------------------------------------------------
+# Features:
+#   1. Create executable target
+#   2. Configure compile options, link libraries, and include paths
+#   3. Use gtest_discover_tests to auto-discover test cases
+#   4. Set test labels
+#
+# Arguments:
+#   SOURCES: Source file list (required)
+#   LABELS: Test labels, e.g. "cpu" "cuda" "distributed" (optional, default "cpu")
+#   TEST_FILTER: gtest test filter pattern (optional)
+#
+# Examples:
+#   # Single-label test (one liner)
+#   infini_train_add_test(test_example SOURCES test_example.cc LABELS cpu)
+#
+#   # Filter same binary by label suffix (one call per label)
+#   infini_train_add_test(test_example SOURCES test_example.cc LABELS cpu TEST_FILTER "-*CUDA*")
+#   infini_train_add_test(test_example_cuda SOURCES test_example.cc LABELS cuda TEST_FILTER "*CUDA*")
+# -----------------------------------------------------------------------------
+macro(infini_train_add_test)
+    cmake_parse_arguments(ARG "" "TEST_NAME;TEST_FILTER" "SOURCES;LABELS" ${ARGN})
+
+    if(NOT ARG_TEST_NAME)
+        set(ARG_TEST_NAME ${ARG_UNPARSED_ARGUMENTS})
+    endif()
+
+    if(NOT ARG_SOURCES)
+        message(FATAL_ERROR "infini_train_add_test: TEST_NAME and SOURCES are required")
+    endif()
+
+    # 1. Create executable target
+    add_executable(${ARG_TEST_NAME} ${ARG_SOURCES})
+
+    # 2. Disable -Werror so tests can run under relaxed warning levels
+    target_compile_options(${ARG_TEST_NAME} PRIVATE -Wno-error)
+
+    # 3. Link Google Test
+    target_link_libraries(${ARG_TEST_NAME} PRIVATE
+        GTest::gtest
+        GTest::gtest_main
+    )
+
+    # 4. Add include paths
+    target_include_directories(${ARG_TEST_NAME} PRIVATE
+        ${TEST_MACROS_DIR}
+        ${glog_SOURCE_DIR}/src
+    )
+
+    # 5. Link project library (reuses framework linking strategy)
+    link_infini_train_exe(${ARG_TEST_NAME})

+    # 6. Auto-discover gtest cases and register as ctest tests
+    set(labels "cpu")
+    if(ARG_LABELS)
+        set(labels "${ARG_LABELS}")
+    endif()
+
+    if(ARG_TEST_FILTER)
+        gtest_discover_tests(${ARG_TEST_NAME}
+            EXTRA_ARGS --gtest_output=xml:%T.xml
+            TEST_FILTER "${ARG_TEST_FILTER}"
+            PROPERTIES LABELS "${labels}"
+        )
+    else()
+        gtest_discover_tests(${ARG_TEST_NAME}
+            EXTRA_ARGS --gtest_output=xml:%T.xml
+            PROPERTIES LABELS "${labels}"
+        )
+    endif()
+endmacro()
+
+# -----------------------------------------------------------------------------
+# infini_train_add_test_suite - Register cpu/cuda/distributed targets in one call
+# -----------------------------------------------------------------------------
+# Calls infini_train_add_test three times (or fewer) with the correct
+# TEST_FILTER and LABELS derived from the label list.
+#
+# Arguments:
+#   Base name; each target is named _