Skip to content

Refactor(linear): split LinearBackward kernel into 3 independent kernels#142

Open
chen2021673 wants to merge 5 commits intomasterfrom
split_linear_backward
Open

Refactor(linear): split LinearBackward kernel into 3 independent kernels#142
chen2021673 wants to merge 5 commits intomasterfrom
split_linear_backward

Conversation

@chen2021673
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 commented Apr 10, 2026

Summary

Architecture refactoring of Linear/Matmul/Outer kernels.

The core idea is separation of concerns — moving the decision of whether a gradient should be computed from the kernel layer up to the autograd layer, making kernels pure compute functions. At the same time, unified GEMM/SGEMV primitives are abstracted at the bottom layer to eliminate duplicated cuBLAS boilerplate.

Changes

  • Autograd layer: LinearBackward and MatmulBackward are each decomposed into multiple independent Dispatcher calls. The needs_input_grad checks happen at the autograd layer, invoking only the kernels actually needed.
  • Kernel layer: The monolithic LinearBackward is split into LinearBackwardInput / LinearBackwardWeight / LinearBackwardBias; MatmulBackward is split into MatmulBackwardInput / MatmulBackwardOther, with naming aligned to MatmulForward(input, other).
  • File split: Matmul kernels are extracted from linear.cc / linear.cu into dedicated cpu/matmul.cc and cuda/matmul.cu, giving each file a single responsibility.
  • GEMM primitive: New gemm.cuh / gemm.cu define the GemmParams struct and GemmCuda(), providing a unified wrapper over cublasGemmEx and cublasGemmStridedBatchedEx branching logic. GetCublasHandle() / GetCudaStream() are centrally defined and shared across linear.cu, matmul.cu, and outer.cu, eliminating duplicate definitions.
  • SGEMV primitive: New SgemvParams struct and SgemvCuda() wrap the cublasSgemv call. LinearForward and LinearBackwardInput in linear.cu take the SGEMV path when bs==1 and fp32 (more efficient for matrix-vector shapes); bf16 falls back to GemmCuda since cublasSgemv does not support it. The fp32 backward path in outer.cu is migrated to SgemvCuda as well, eliminating inline cublasSgemv calls.

@chen2021673 chen2021673 force-pushed the split_linear_backward branch 3 times, most recently from 283d083 to 23d301b Compare April 15, 2026 01:58
@chen2021673 chen2021673 requested a review from kilinchange April 15, 2026 02:08
Move grad_flags logic from kernel to autograd layer. The
monolithic LinearBackward kernel is replaced by LinearBackwardInput,
LinearBackwardWeight, and LinearBackwardBias — each a pure compute
operation with no autograd-related parameters.
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel
is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or
  cublasGemmStridedBatchedEx based on batch_count), GetCublasHandle(), GetCudaStream()
  shared across all GEMM-using kernels
- Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated
  matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther
  for semantic clarity matching MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths);
  keep cublasSgemv for the fp32 backward path (more efficient, bf16 unsupported)
@chen2021673 chen2021673 force-pushed the split_linear_backward branch 2 times, most recently from ae80cec to 88579ba Compare April 28, 2026 09:06
@chen2021673 chen2021673 force-pushed the split_linear_backward branch from 88579ba to 252e6cd Compare April 28, 2026 09:21
@Chamberlain0w0
Copy link
Copy Markdown
Contributor

另外我看 cpu 的改动也挺多,但看不出什么问题,最好也辛苦验证下精度没问题

Comment thread infini_train/src/kernels/cuda/linear.cu Outdated
// C = output.T[out_features, bs]
// A = weight.T[out_features, in_features]
// B = input.T[in_features, bs]
GemmParams p;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块感觉改成 C++20 designated initializer 的形式,会稍微明确且好看一点(但可能得注意下每个成员的顺序)。前面后面也是一样。

GemmParams p{
     .trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
     .trans_b = CUBLAS_OP_N;

     .m = static_cast<int>(out_features);
     .n = static_cast<int>(bs);
     .k = static_cast<int>(in_features);

    ...
};

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
};

saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

matmul 的 saved_tensors_ 里面,现在是有可能存 nullptr 的,但问题是后面的 Backward 仍然是从 saved_tensors_ 读取 input1/input2 然后做操作,原则上有空指针 seg fault 的风险

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment thread infini_train/src/autograd/matmul.cc Outdated
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto device = input1->GetDevice().type();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

比如这里,并没有判断 input1 是否为 nullptr,默认是有值的

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续 ctest PR 合了之后我加一下针对这个 case 的单测

DataType input_dtype; // dtype of A and B
DataType output_dtype; // dtype of C (may differ, e.g. bf16 in → fp32 out)

cublasHandle_t blas_handle = nullptr;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我感觉 handle 不作为成员变量存在这个结构体里比较好, GemmParams 应该只是一个问题描述,不应该涉及任何运行时的东西,这样的话 GemmCuda 的接口可以改成 GemmCuda(device, p),然后 GetHandle 的行为也都可以统一到 GemmCuda 里面,就不需要再包个全局的 GetCublasHandle() 的 helper 了

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

其实我感觉 GetCudaStream 的 helper 也怪怪的,如果有这种 helper 为啥不共享给全部 kernel 用,放这只给 gemm 的几个😂

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nsdd,调用使用哪个 cuBLAS handle,绑定哪个 CUDA stream,在哪个 device 上执行应该是 kernel launcher / runtime 层应该负责的事情。我改一下。
GetCudaStream 感觉 要么放gemm_cuda.cu 匿名 namespace 只给这里用,要么放到统一 CUDA runtime 头文件里做公共函数?

…s to designated initializers

- Save input1_dims_/input2_dims_ in Matmul::SetupContext to avoid Dims()
  calls on potentially-null saved tensors in Backward
- Get device from grad_output instead of input1 in Matmul::Backward
- Add CHECK guards before dereferencing nullable saved tensors
- Convert all GemmParams/SgemvParams construction in linear.cu, matmul.cu,
  outer.cu to C++20 designated initializer form
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants