Refactor(linear): split LinearBackward kernel into 3 independent kernels #142
chen2021673 wants to merge 5 commits into master from …
Conversation
force-pushed from 283d083 to 23d301b
Move grad_flags logic from kernel to autograd layer. The monolithic LinearBackward kernel is replaced by LinearBackwardInput, LinearBackwardWeight, and LinearBackwardBias — each a pure compute operation with no autograd-related parameters.
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
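The two commits above share one pattern: the autograd layer now owns the needs_input_grad decision and calls each backward kernel only when its gradient is actually required. A minimal sketch of that pattern, using the saved_tensors_ / needs_input_grad_ members that appear in the diffs below but with otherwise illustrative signatures (the real Linear::Backward may differ):

```cpp
// Sketch only: the autograd layer owns the grad decision, and the three
// kernels are pure compute ops with no autograd-related parameters.
std::vector<std::shared_ptr<Tensor>> Linear::Backward(
    const std::shared_ptr<Tensor>& grad_output) {
  const auto& input = saved_tensors_[0];   // saved in SetupContext
  const auto& weight = saved_tensors_[1];

  std::shared_ptr<Tensor> grad_input, grad_weight, grad_bias;
  if (needs_input_grad_[0]) grad_input = LinearBackwardInput(grad_output, weight);
  if (needs_input_grad_[1]) grad_weight = LinearBackwardWeight(grad_output, input);
  if (needs_input_grad_.size() > 2 && needs_input_grad_[2])
    grad_bias = LinearBackwardBias(grad_output);
  return {grad_input, grad_weight, grad_bias};
}
```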
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or cublasGemmStridedBatchedEx based on batch_count), plus GetCublasHandle() and GetCudaStream() shared across all GEMM-using kernels
- Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput and MatmulBackwardInput2 → MatmulBackwardOther for semantic clarity, matching the MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths); keep cublasSgemv for the fp32 backward path (more efficient, and bf16 is unsupported there)
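A rough sketch of what the new gemm.cuh interface amounts to, based only on the commit message: GemmParams as a problem description plus a single GemmCuda() launcher that branches on batch_count. Pointer/stride/alpha-beta plumbing is omitted, the handle argument is a stand-in (the actual handle/stream handling is discussed in the review thread below), and the real struct uses the project's DataType enum rather than cudaDataType_t:

```cpp
// Sketch of gemm.cuh, not the actual file.
#include <cublas_v2.h>

struct GemmParams {
  cublasOperation_t trans_a = CUBLAS_OP_N;
  cublasOperation_t trans_b = CUBLAS_OP_N;
  int m = 0, n = 0, k = 0;
  int batch_count = 1;                       // > 1 selects the strided-batched path
  cudaDataType_t input_dtype = CUDA_R_32F;   // dtype of A and B
  cudaDataType_t output_dtype = CUDA_R_32F;  // dtype of C, may differ (bf16 in, fp32 out)
};

inline void GemmCuda(cublasHandle_t handle, const GemmParams& p,
                     const void* a, const void* b, void* c) {
  if (p.batch_count > 1) {
    // cublasGemmStridedBatchedEx(handle, p.trans_a, p.trans_b, p.m, p.n, p.k, ...);
  } else {
    // cublasGemmEx(handle, p.trans_a, p.trans_b, p.m, p.n, p.k, ...);
  }
  (void)handle; (void)a; (void)b; (void)c;   // placeholders in this sketch
}
```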
force-pushed from ae80cec to 88579ba
…es in linear kernels
force-pushed from 88579ba to 252e6cd
Also, I see the CPU changes are quite extensive too. I can't spot any problems, but it would be good to take the trouble to verify that precision is unaffected.
// C = output.T[out_features, bs]
// A = weight.T[out_features, in_features]
// B = input.T[in_features, bs]
GemmParams p;
I think rewriting this as a C++20 designated initializer would be slightly clearer and nicer looking (though you may need to watch the order of the members). Same for the code before and after this.
GemmParams p{
    .trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
    .trans_b = CUBLAS_OP_N,
    .m = static_cast<int>(out_features),
    .n = static_cast<int>(bs),
    .k = static_cast<int>(in_features),
    ...
};

return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
};

saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
matmul's saved_tensors_ can now contain nullptr, but the problem is that Backward still reads input1/input2 from saved_tensors_ and operates on them, so in principle there is a risk of a null-pointer segfault.
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto device = input1->GetDevice().type();
For example, here there is no check on whether input1 is nullptr; it is simply assumed to hold a value.
I'll add a unit test for this case once the ctest PR is merged.
DataType input_dtype;   // dtype of A and B
DataType output_dtype;  // dtype of C (may differ, e.g. bf16 in → fp32 out)

cublasHandle_t blas_handle = nullptr;
I feel it's better not to keep the handle as a member of this struct. GemmParams should just be a description of the problem and shouldn't involve anything runtime-related. That way, the GemmCuda interface can become GemmCuda(device, p), the GetHandle behavior can be unified inside GemmCuda, and there's no need for a separate global GetCublasHandle() helper.
Actually, the GetCudaStream helper also feels a bit odd. If we have such a helper, why not share it with all kernels instead of putting it here just for the few GEMM ones? 😂
You're right. Which cuBLAS handle to use, which CUDA stream to bind, and which device to execute on should be the responsibility of the kernel launcher / runtime layer. I'll change it.
For GetCudaStream, should it go into an anonymous namespace in gemm_cuda.cu for local use only, or into a unified CUDA runtime header as a shared function?
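A sketch of where the thread lands, assuming the GemmParams struct from the commit above: the struct stays a pure problem description, GemmCuda takes the device explicitly, and the handle/stream lookup lives next to the launcher (the int device id and the helper bodies are placeholders, and the anonymous-namespace placement is only one of the two options raised):

```cpp
// Sketch of the handle-free design discussed above.
#include <cublas_v2.h>

namespace {
// Placeholder lookups; real versions would cache one handle per device and
// return the stream the runtime currently binds to that device.
cublasHandle_t GetCublasHandle(int /*device*/) { return nullptr; }
cudaStream_t GetCudaStream(int /*device*/) { return nullptr; }
}  // namespace

void GemmCuda(int device, const GemmParams& p,
              const void* a, const void* b, void* c) {
  cublasHandle_t handle = GetCublasHandle(device);
  cublasSetStream(handle, GetCudaStream(device));
  // ... dispatch to cublasGemmEx / cublasGemmStridedBatchedEx as before ...
  (void)p; (void)a; (void)b; (void)c;   // placeholders in this sketch
}
```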
…s to designated initializers

- Save input1_dims_ / input2_dims_ in Matmul::SetupContext to avoid Dims() calls on potentially-null saved tensors in Backward
- Get the device from grad_output instead of input1 in Matmul::Backward
- Add CHECK guards before dereferencing nullable saved tensors
- Convert all GemmParams/SgemvParams construction in linear.cu, matmul.cu, outer.cu to C++20 designated initializer form
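Roughly what the null-safety part of this commit looks like in Matmul::Backward, assuming the project's Tensor, CHECK macro, and the members named in the commit message; the kernel call signatures are illustrative:

```cpp
// Sketch: device comes from grad_output, dims come from members captured in
// SetupContext, and saved tensors are only dereferenced under their flag.
std::vector<std::shared_ptr<Tensor>> Matmul::Backward(
    const std::shared_ptr<Tensor>& grad_output) {
  auto device = grad_output->GetDevice().type();  // not input1: it may be null
  (void)device;  // selects the CPU vs CUDA kernel path (elided in this sketch)

  // input1_dims_ / input2_dims_ were saved in SetupContext, so Backward never
  // calls Dims() on a tensor that may not have been saved.
  std::shared_ptr<Tensor> grad_input1, grad_input2;
  if (needs_input_grad_[0]) {
    const auto& input2 = saved_tensors_[1];  // needed for grad of input1
    CHECK(input2 != nullptr);                // guard before dereference
    grad_input1 = MatmulBackwardInput(grad_output, input2);
  }
  if (needs_input_grad_[1]) {
    const auto& input1 = saved_tensors_[0];  // needed for grad of input2
    CHECK(input1 != nullptr);
    grad_input2 = MatmulBackwardOther(grad_output, input1);
  }
  return {grad_input1, grad_input2};
}
```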
Summary
Architecture refactoring of Linear/Matmul/Outer kernels.
The core idea is separation of concerns: the decision of whether a gradient should be computed moves from the kernel layer up to the autograd layer, turning the kernels into pure compute functions. At the same time, unified GEMM/SGEMV primitives are factored out at the bottom layer to eliminate duplicated cuBLAS boilerplate.
Changes