
Refactor(linear): split LinearBackward kernel into 3 independent kernels #142

Open

chen2021673 wants to merge 3 commits into master from split_linear_backward

Conversation


@chen2021673 (Contributor) commented Apr 10, 2026

Overview

This PR completes an architectural refactor of the Linear/Matmul/Outer kernels.

The core idea is separation of concerns: the decision of whether a gradient needs to be computed moves from the kernel layer up into the autograd layer, so each kernel becomes a pure compute function. At the same time, a unified GEMM primitive is introduced at the bottom layer to eliminate duplicated cuBLAS call boilerplate.

Changes

  • autograd layer: LinearBackward and MatmulBackward are each decomposed into multiple independent Dispatcher calls; the needs_input_grad checks happen in the autograd layer, which invokes only the kernels that are actually needed.
  • kernel layer: the monolithic LinearBackward becomes three independent kernels, LinearBackwardInput / LinearBackwardWeight / LinearBackwardBias; the monolithic MatmulBackward becomes two independent kernels, MatmulBackwardInput / MatmulBackwardOther, named to match MatmulForward(input, other).
  • File split: the Matmul kernels move out of linear.cc / linear.cu into new cpu/matmul.cc and cuda/matmul.cu files, so each file has a single responsibility.
  • GEMM primitive: new gemm.cuh / gemm.cu define a GemmParams struct and GemmCuda(), which wraps the branching between cublasGemmEx and cublasGemmStridedBatchedEx in one place (see the sketch after this list). GetCublasHandle() / GetCudaStream() are now defined centrally and shared by linear.cu, matmul.cu, and outer.cu, removing the duplicate definitions in each file.
  • outer.cu: updated in the same spirit to route its matrix-multiply calls through GemmCuda(), keeping cublasSgemv only on the fp32 backward path (more efficient for the matrix × vector shape, and cuBLAS has no bf16 GEMV).
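A minimal sketch of what the GemmParams / GemmCuda() primitive could look like, assuming fp32 accumulation and default algorithm selection; the field names are illustrative, not the PR's actual definitions:

```cpp
#include <cublas_v2.h>

// Hypothetical shape of the shared GEMM parameter struct.
struct GemmParams {
  cublasOperation_t trans_a, trans_b;      // op(A), op(B)
  int m, n, k;                             // GEMM dimensions
  const void *a, *b;                       // device pointers
  void *c;
  int lda, ldb, ldc;                       // leading dimensions
  long long stride_a, stride_b, stride_c;  // per-batch strides
  int batch_count;                         // 1 => plain GEMM
  cudaDataType_t dtype;                    // CUDA_R_32F or CUDA_R_16BF
};

// Dispatch to cublasGemmEx or cublasGemmStridedBatchedEx based on batch_count.
inline cublasStatus_t GemmCuda(cublasHandle_t handle, const GemmParams &p) {
  const float alpha = 1.0f, beta = 0.0f;  // fp32 scaling for CUBLAS_COMPUTE_32F
  if (p.batch_count > 1) {
    return cublasGemmStridedBatchedEx(
        handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &alpha,
        p.a, p.dtype, p.lda, p.stride_a,
        p.b, p.dtype, p.ldb, p.stride_b, &beta,
        p.c, p.dtype, p.ldc, p.stride_c, p.batch_count,
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
  }
  return cublasGemmEx(
      handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &alpha,
      p.a, p.dtype, p.lda, p.b, p.dtype, p.ldb, &beta,
      p.c, p.dtype, p.ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
```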

Move grad_flags logic from kernel to autograd layer. The
monolithic LinearBackward kernel is replaced by LinearBackwardInput,
LinearBackwardWeight, and LinearBackwardBias — each a pure compute
operation with no autograd-related parameters.
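As a rough illustration of this pattern (a hypothetical sketch; Tensor, LinearContext, and CallKernel are stand-ins for the repo's actual autograd and dispatcher API):

```cpp
#include <array>
#include <optional>

// Stand-in types: the repo's real tensor/context/dispatcher are assumed here.
struct Tensor { /* opaque device buffer */ };
Tensor CallKernel(const char *name, const Tensor &a, const Tensor *b = nullptr);

struct LinearContext {
  Tensor saved_input, saved_weight;
  std::array<bool, 3> needs_input_grad;  // {input, weight, bias}
};

// Backward entry point in the autograd layer: the needs_input_grad checks
// live here, and each kernel below is a pure compute function.
std::array<std::optional<Tensor>, 3> LinearBackward(const LinearContext &ctx,
                                                    const Tensor &grad_output) {
  std::array<std::optional<Tensor>, 3> grads;
  if (ctx.needs_input_grad[0])  // d(input) depends on the saved weight
    grads[0] = CallKernel("LinearBackwardInput", grad_output, &ctx.saved_weight);
  if (ctx.needs_input_grad[1])  // d(weight) depends on the saved input
    grads[1] = CallKernel("LinearBackwardWeight", grad_output, &ctx.saved_input);
  if (ctx.needs_input_grad[2])  // d(bias) reduces grad_output over rows
    grads[2] = CallKernel("LinearBackwardBias", grad_output);
  return grads;
}
```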
@chen2021673 force-pushed the split_linear_backward branch from 8f64209 to 66e45dc on April 10, 2026 at 08:18
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel
is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
@chen2021673 force-pushed the split_linear_backward branch from 66e45dc to be6eed3 on April 10, 2026 at 08:23
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or
  cublasGemmStridedBatchedEx based on batch_count), GetCublasHandle(), GetCudaStream()
  shared across all GEMM-using kernels
- Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated
  matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther
  for semantic clarity matching MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths);
  keep cublasSgemv for the fp32 backward path (more efficient, bf16 unsupported)
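A hedged sketch of what that outer.cu backward branch could look like (the function and parameter names are illustrative, and it reuses the GemmParams / GemmCuda() sketch above): fp32 keeps cublasSgemv for the matrix × vector shape, while bf16 is expressed as an [m, n] x [n, 1] GEMM because cuBLAS has no bf16 GEMV.

```cpp
void OuterBackwardInputCuda(cublasHandle_t handle, cudaDataType_t dtype,
                            const void *grad_out,  // [m, n], row-major
                            const void *other,     // [n]
                            void *grad_input,      // [m]
                            int m, int n) {
  if (dtype == CUDA_R_32F) {
    const float alpha = 1.0f, beta = 0.0f;
    // A row-major [m, n] buffer is a column-major [n, m] matrix, so reduce
    // over n with CUBLAS_OP_T: grad_input = grad_out * other.
    cublasSgemv(handle, CUBLAS_OP_T, n, m, &alpha,
                static_cast<const float *>(grad_out), n,
                static_cast<const float *>(other), 1, &beta,
                static_cast<float *>(grad_input), 1);
    return;
  }
  // bf16 path: the same matrix-vector product as an [m, n] x [n, 1] GEMM,
  // routed through the shared primitive.
  GemmParams p{};
  p.trans_a = CUBLAS_OP_T;  p.trans_b = CUBLAS_OP_N;
  p.m = m;  p.n = 1;  p.k = n;
  p.a = grad_out;   p.lda = n;
  p.b = other;      p.ldb = n;
  p.c = grad_input; p.ldc = m;
  p.stride_a = p.stride_b = p.stride_c = 0;
  p.batch_count = 1;
  p.dtype = CUDA_R_16BF;
  GemmCuda(handle, p);
}
```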
@chen2021673 force-pushed the split_linear_backward branch from 283d083 to 23d301b on April 15, 2026 at 01:58
@chen2021673 requested a review from kilinchange on April 15, 2026 at 02:08