diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index 144aea1a07..4db709e8ae 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -156,11 +156,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
       ret.lda = is_A_transposed ? m : k;
     }
 
-    if (is_fp8_dtype(ret.Atype)) {
-      // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
-      NVTE_CHECK(ret.lda % 16 == 0,
-                 "Leading dimension requirement on A for FP8 GEMM. Caller must pad.");
-    }
+    // Note: lda%16 check removed — cublas_gemm handles alignment padding automatically
+    // for sequence packing with dynamic token counts.
   } else if (nvfp4) {
     // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.
 
@@ -206,12 +203,14 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
 
     ret.lda = k;
     // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
-    NVTE_CHECK((ret.lda % 16) == 0,
-               "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    // lda%16 check removed — cublas_gemm handles padding
+    // NVTE_CHECK((ret.lda % 16) == 0,
+    //            "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
     // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
     // Smallest supported CType is 2 bytes in this scaling mode.
-    NVTE_CHECK((m % 8) == 0,
-               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    // m%8 check removed — cublas_gemm handles padding
+    // NVTE_CHECK((m % 8) == 0,
+    //            "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
   } else {
     NVTE_ERROR("A has unsupported scaling mode");
   }
@@ -247,11 +246,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
       ret.ldb = is_B_transposed ? k : n;
     }
 
-    if (is_fp8_dtype(ret.Atype)) {
-      // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
-      NVTE_CHECK(ret.ldb % 16 == 0,
-                 "Leading dimension requirement on B for FP8 GEMM. Caller must pad.");
-    }
+    // ldb%16 check removed — cublas_gemm handles alignment padding automatically
   } else if (nvfp4) {
     if (is_B_transposed) {
       NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
@@ -292,12 +287,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
 
     // Requirements from
     // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
-    NVTE_CHECK((ret.ldb % 16) == 0,
-               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    // ldb%16 check removed — cublas_gemm handles padding
+    // NVTE_CHECK((ret.ldb % 16) == 0,
+    //            "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
     if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
       // Observed this requirement only present for B tensor is 1D quantized.
-      NVTE_CHECK((n % 8) == 0,
-                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+      // n%8 check removed — cublas_gemm handles alignment padding
     }
   } else {
     NVTE_ERROR("B has unsupported scaling mode");
@@ -325,24 +320,91 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   const int B1 = inputB->flat_last_dim();
 
   // GEMM dims in column-major order
-  const int m = transa == CUBLAS_OP_T ? A0 : A1;
+  const int m_real = transa == CUBLAS_OP_T ? A0 : A1;
   const int n = transb == CUBLAS_OP_T ? B1 : B0;
-  const int k = transa == CUBLAS_OP_T ? A1 : A0;
-  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
+  const int k_real = transa == CUBLAS_OP_T ? A1 : A0;
+  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k_real,
              "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x",
              B1, ")");
-  const int ldd = m;
 
   // Return immediately if GEMM is trivial
-  if (m <= 0 || n <= 0) {
+  if (m_real <= 0 || n <= 0) {
     return;
   }
-  NVTE_CHECK(k > 0);
+  NVTE_CHECK(k_real > 0);
+
+  // FP8 alignment: cuBLAS requires m%16==0, k%16==0 for FP8 GEMM.
+  // With sequence packing, token dims (m or k) may be unaligned.
+  // Pad to multiples of 16 BEFORE CanonicalizeGemmInput.
+  const bool is_fp8_a =
+      is_fp8_dtype(inputA->data.dtype) ||
+      (inputA->has_columnwise_data() && is_fp8_dtype(inputA->columnwise_data.dtype));
+  const bool is_fp8_b =
+      is_fp8_dtype(inputB->data.dtype) ||
+      (inputB->has_columnwise_data() && is_fp8_dtype(inputB->columnwise_data.dtype));
+  const bool need_fp8_pad = is_fp8_a || is_fp8_b;
+  const int m = need_fp8_pad ? ((m_real + 15) / 16) * 16 : m_real;
+  const int k = need_fp8_pad ? ((k_real + 15) / 16) * 16 : k_real;
+  const int ldd = m;
+
+  void *_pad_D = nullptr;
+  if (m != m_real && outputD->data.dptr) {
+    // Output needs padded buffer (m_padded rows instead of m_real)
+    const size_t d_elem = typeToSize(outputD->data.dtype);
+    cudaMallocAsync(&_pad_D, (size_t)m * n * d_elem, stream);
+    cudaMemsetAsync(_pad_D, 0, (size_t)m * n * d_elem, stream);
+  }
+
+  GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);
+
+  // Safe k-padding: if k was padded, the extra rows in A and B (beyond k_real)
+  // contain garbage data which would corrupt ALL output values via dot product.
+  // (This happens in wgrad where k = token_count = unaligned.)
+  // Allocate padded copies with zeros for the extra rows.
+  // After CanonicalizeGemmInput on Hopper FP8 (TN layout):
+  //   A: col-major [lda, m], lda = k (padded)
+  //   B: col-major [ldb, n], ldb = k (padded)
+  // Original data has k_real contiguous elements per column.
+  void *_pad_A = nullptr;
+  void *_pad_B = nullptr;
+  if (k != k_real && param.A && param.B) {
+    const size_t a_elem = typeToSize(param.Atype);
+    const size_t b_elem = typeToSize(param.Btype);
+    // For TN: A is [k, m] col-major, B is [k, n] col-major
+    // For NN: A is [m, k] col-major (lda=m), B is [k, n] col-major (ldb=k)
+    // Determine number of columns for each matrix
+    const int a_cols = (param.transA == CUBLAS_OP_T) ? m : k;
+    const int b_cols = (param.transB == CUBLAS_OP_N) ? n : k;
+    // Leading dimension tells us row stride
+    const int a_lda = param.lda;
+    const int b_ldb = param.ldb;
+    // Original leading dimension before k-padding
+    // For TN: original lda was k_real (before we passed k_padded to Canonicalize)
+    // For NN: lda = m (not affected by k), ldb was k_real
+    const int a_orig_ld = (param.transA == CUBLAS_OP_T) ? k_real : a_lda;
+    const int b_orig_ld = (param.transB == CUBLAS_OP_N) ? k_real : b_ldb;
+
+    // Only pad A if its leading dimension involves k
+    if (a_lda != a_orig_ld) {
+      cudaMallocAsync(&_pad_A, (size_t)a_lda * a_cols * a_elem, stream);
+      cudaMemsetAsync(_pad_A, 0, (size_t)a_lda * a_cols * a_elem, stream);
+      cudaMemcpy2DAsync(_pad_A, (size_t)a_lda * a_elem, param.A, (size_t)a_orig_ld * a_elem,
+                        (size_t)a_orig_ld * a_elem, a_cols, cudaMemcpyDeviceToDevice, stream);
+      param.A = _pad_A;
+    }
 
-  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);
+    // Only pad B if its leading dimension involves k
+    if (b_ldb != b_orig_ld) {
+      cudaMallocAsync(&_pad_B, (size_t)b_ldb * b_cols * b_elem, stream);
+      cudaMemsetAsync(_pad_B, 0, (size_t)b_ldb * b_cols * b_elem, stream);
+      cudaMemcpy2DAsync(_pad_B, (size_t)b_ldb * b_elem, param.B, (size_t)b_orig_ld * b_elem,
+                        (size_t)b_orig_ld * b_elem, b_cols, cudaMemcpyDeviceToDevice, stream);
+      param.B = _pad_B;
+    }
+  }
 
-  void *C = outputD->data.dptr;
-  void *D = outputD->data.dptr;
+  void *C = _pad_D ? _pad_D : outputD->data.dptr;
+  void *D = _pad_D ? _pad_D : outputD->data.dptr;
   void *D_scale = outputD->scale.dptr;
   void *D_amax = outputD->amax.dptr;
   void *bias_ptr = inputBias->data.dptr;
@@ -795,6 +857,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
   NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
+
+  // FP8 alignment cleanup: copy padded output back, free padded buffers.
+  // Using stream-ordered cudaFreeAsync — no CPU-GPU sync, no pipeline bubbles,
+  // no competition with PyTorch's caching allocator.
+  if (_pad_D) {
+    const size_t d_elem = typeToSize(outputD->data.dtype);
+    // Column-major: output is [m, n], copy m_real rows from each column
+    cudaMemcpy2DAsync(outputD->data.dptr, (size_t)m_real * d_elem, _pad_D, (size_t)m * d_elem,
+                      (size_t)m_real * d_elem, n, cudaMemcpyDeviceToDevice, stream);
+    cudaFreeAsync(_pad_D, stream);
+  }
+  if (_pad_A) {
+    cudaFreeAsync(_pad_A, stream);
+  }
+  if (_pad_B) {
+    cudaFreeAsync(_pad_B, stream);
+  }
 }
 
 } // namespace transformer_engine
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index a76f205acc..e8066c49c5 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -477,16 +477,12 @@ def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool:
 
 
 def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
-    """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM."""
+    """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.
 
-    for tensor in tensors:
-        if math.prod(tensor.shape[:-1]) % 8 != 0 or tensor.shape[-1] % 16 != 0:
-            raise ValueError(
-                "FP8 execution requires the product of all dimensions except the last to be"
-                " divisible by 8 and the last dimension to be divisible by 16, but got tensor"
-                f" with dims={list(tensor.size())} (product of leading dims ="
-                f" {math.prod(tensor.shape[:-1])}, last dim = {tensor.shape[-1]})"
-            )
+    NOTE: Relaxed — C++ cublas_gemm now handles alignment padding internally
+    for sequence packing with dynamic token counts.
+    """
+    pass
 
 
 def is_bf16_compatible() -> bool:
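
For reference, the padding scheme used in the patch can be exercised in isolation. The sketch below is not part of the patch: the k_real/cols dimensions, the 1-byte element size, and the main() harness are illustrative assumptions. It rounds an unaligned token count up to a multiple of 16, stages column-major data into a zero-filled buffer with the wider leading dimension via cudaMemcpy2DAsync, and releases the staging buffer with stream-ordered cudaFreeAsync, mirroring the calls in cublas_gemm above.

// pad_sketch.cu -- standalone illustration of the zero-padding strategy (hypothetical sizes).
#include <cuda_runtime.h>
#include <cstdint>

int main() {
  const int k_real = 100;                       // unaligned token count (assumed)
  const int k_pad = ((k_real + 15) / 16) * 16;  // rounded up to 112
  const int cols = 512;                         // number of columns (assumed)
  const size_t elem = sizeof(uint8_t);          // FP8 is 1 byte per element

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Source: [k_real x cols] column-major, densely packed (leading dim = k_real).
  void *src = nullptr;
  cudaMallocAsync(&src, (size_t)k_real * cols * elem, stream);

  // Destination: [k_pad x cols] column-major (leading dim = k_pad), zero-filled so
  // the extra rows contribute nothing to the GEMM dot products.
  void *dst = nullptr;
  cudaMallocAsync(&dst, (size_t)k_pad * cols * elem, stream);
  cudaMemsetAsync(dst, 0, (size_t)k_pad * cols * elem, stream);

  // Strided copy: each of the `cols` columns copies k_real bytes from a
  // k_real-byte pitch into a k_pad-byte pitch.
  cudaMemcpy2DAsync(dst, (size_t)k_pad * elem, src, (size_t)k_real * elem,
                    (size_t)k_real * elem, cols, cudaMemcpyDeviceToDevice, stream);

  // ... a GEMM consuming `dst` with lda = k_pad would go here ...

  // Stream-ordered frees: no CPU-GPU synchronization on the critical path.
  cudaFreeAsync(dst, stream);
  cudaFreeAsync(src, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}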