diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 4bfe06095b..30c360d6ee 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -14,6 +14,7 @@ from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, + RecipeState, get_align_size_for_quantization, ) from transformer_engine.pytorch.utils import ( @@ -1271,6 +1272,168 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) +def _tight_linear_dims(recipe_obj) -> Tuple[int, int, int, int, int, int]: + """Return tightest (M, N, K) for inference and training for a recipe. + + Constraints: + - cuBLAS needs 16B-aligned leading dims: 16 elts for FP8, 32 for FP4. + - Linear(K, N) fprop has lda=ldb=K, ldc=N (no M alignment needed). + - Training adds wgrad/dgrad with lda/ldb/ldc spanning M, N, K -> all must be aligned. + - Float8BlockScaling swizzle requires data_rows (first dim) % 4. + - NVFP4 RHT requires input first dim % 16 (subsumed by 32-elt alignment for training). + """ + if recipe_obj.delayed() or recipe_obj.float8_current_scaling(): + return (1, 16, 16, 16, 16, 16) + if recipe_obj.mxfp8(): + # MXFP8 needs each scaling dim >= 32-element block size; K for fprop, + # and M/N as well for wgrad/dgrad in training. + return (1, 16, 32, 32, 32, 32) + if recipe_obj.float8_block_scaling(): + # 1D block-scaled FP8 GEMM (native Hopper path) requires batch dim %8. + return (8, 128, 128, 16, 128, 128) + if recipe_obj.nvfp4(): + return (16, 32, 32, 32, 32, 32) + raise ValueError(f"Unknown recipe: {recipe_obj.__class__.__name__}") + + +@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("inference", [True, False], ids=["inference", "training"]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=["bf16", "fp32"]) +def test_linear_tight_dims(recipe, inference, dtype): + """te.Linear with the tightest M/N/K per recipe, vs a pytorch baseline. + + Previously the Python assert_dim_for_fp8_exec rejected any M not divisible + by 8. With that guard removed, the C++ quantizers, swizzle kernel, and + cuBLAS are the real source of dimension constraints — they are looser and + recipe-specific. This test exercises the tightest shape each recipe should + accept and compares against a high-precision baseline. + + Both sides see the same dequantized input/weight, so the only difference is + FP8/FP4 tensor-core vs bf16/fp32 tensor-core accumulation roundoff. + """ + if recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip(f"{dtype} not supported for this NVFP4 recipe") + if not inference and recipe.fp4_quant_bwd_grad.random_hadamard_transform: + # The reference baseline cannot express the Hadamard transform that + # TE's wgrad GEMM applies to grad_output and input; their product is + # grad^T · H^T · H · x which is only ≈ grad^T · x after quantization + # noise averages out. + pytest.skip("NVFP4 with RHT: backward comparison needs Hadamard-aware baseline") + + # Seed deterministically so outputs don't depend on pytest parametrization order. + torch.manual_seed(0) + + m_inf, n_inf, k_inf, m_tr, n_tr, k_tr = _tight_linear_dims(recipe) + M, N, K = (m_inf, n_inf, k_inf) if inference else (m_tr, n_tr, k_tr) + + device = "cuda" + + # For the fp32 path, the baseline must stay in full FP32 — the default TF32 + # tensor cores would drop 13 mantissa bits and diverge far beyond the FP32 + # tolerance we want to check. 
+ prev_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + prev_tf32_cudnn = torch.backends.cudnn.allow_tf32 + if dtype == torch.float32: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + try: + te_linear = Linear(K, N, bias=False, params_dtype=dtype, device=device) + torch_linear = torch.nn.Linear(K, N, bias=False, device=device, dtype=dtype) + + # Forward-path quantizers: [input, weight, output]. + fwd_state = RecipeState.create(recipe, mode="forward", num_quantizers=3) + input_quantizer, weight_quantizer, _ = fwd_state.make_quantizers() + + # Share weights: TE gets the raw weight (it quantizes internally); the + # baseline gets dequantize(quantize(W)) so both do the same matmul. + W = torch.randn(N, K, dtype=dtype, device=device) + with torch.no_grad(): + te_linear.weight.copy_(W) + torch_linear.weight.copy_(weight_quantizer(W).dequantize()) + + # Input: feed x_qdq to BOTH sides. TE quantizes internally (near-idempotent + # on an already-quantized tensor). Feeding raw x only to TE would let + # current-scaling recipes recompute a slightly different scale from the + # already-quantized amax and diverge from the baseline. + x_raw = torch.randn(M, K, dtype=dtype, device=device) + x_qdq = input_quantizer(x_raw).dequantize() + + requires_grad = not inference + x_te = x_qdq.clone().detach().requires_grad_(requires_grad) + x_ref = x_qdq.clone().detach().requires_grad_(requires_grad) + + if inference: + with autocast(enabled=True, recipe=recipe), torch.no_grad(): + te_out = te_linear(x_te) + else: + with autocast(enabled=True, recipe=recipe): + te_out = te_linear(x_te) + ref_out = torch_linear(x_ref) + torch.cuda.synchronize() + + # Tolerance sized for tensor-core accumulation differences between the + # FP8/FP4 path (TE) and the high-precision reference. Both sides + # multiply the same dequantized values, but the FP8 tensor-core tiles + # reduce in a different order than the bf16/fp32 tensor core. Empirical + # values measured across Hopper (H100) and Blackwell (B200): + # - BF16 output: ~1 BF16 ULP per element (up to ~4 on cancelling sums). + # - FP32 output w/ full FP32 accumulator: ~1e-3 absolute on Hopper + # (fast-accumulator FP8 GEMM), ~1e-5 on Blackwell. + # - FP32 output, Float8BlockScaling on Hopper: effectively BF16 + # precision — the native 1D block-scaled FP8 GEMM uses a + # lower-precision block accumulator. + if dtype == torch.bfloat16: + tols = dict(rtol=1.6e-2, atol=3e-2) + elif recipe.float8_block_scaling(): + tols = dict(rtol=1.6e-2, atol=3e-2) + else: + tols = dict(rtol=1e-3, atol=5e-3) + torch.testing.assert_close(te_out, ref_out, **tols) + + if not inference: + # Quantize+dequantize grad_output so both paths see the same signal. + bwd_state = RecipeState.create(recipe, mode="backward", num_quantizers=2) + grad_output_quantizer = bwd_state.make_quantizers()[0] + grad_raw = torch.randn_like(te_out) + grad_qdq = grad_output_quantizer(grad_raw).dequantize() + + te_out.backward(grad_qdq) + ref_out.backward(grad_qdq) + torch.cuda.synchronize() + + torch.testing.assert_close(x_te.grad, x_ref.grad, **tols) + torch.testing.assert_close(te_linear.weight.grad, torch_linear.weight.grad, **tols) + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32_matmul + torch.backends.cudnn.allow_tf32 = prev_tf32_cudnn + + +def test_mxfp8_scale_shape_partial_block(): + """MXFP8 Python tensor helpers must ceildiv the scale dim, not floor-div. 
+ + The C++ quantizer uses DIVUP so a 16-element trailing partial block gets + one scale; if the Python side floor-divides instead, scale tensors + collapse to zero size. This test exercises the newly-allowed last_dim=16 + path that the end-to-end GEMM test can't hit (the MXFP8 GEMM kernel + itself still requires K >= 32). + """ + if not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=True) + + rowwise = quantizer.get_scale_shape((32, 16), columnwise=False) + columnwise = quantizer.get_scale_shape((32, 16), columnwise=True) + assert all(d > 0 for d in rowwise), f"rowwise scale collapsed: {rowwise}" + assert all(d > 0 for d in columnwise), f"columnwise scale collapsed: {columnwise}" + + empty = quantizer.make_empty((32, 16), dtype=torch.bfloat16, device="cuda") + assert empty._rowwise_scale_inv.numel() > 0 + assert empty._columnwise_scale_inv.numel() > 0 + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 144aea1a07..7105c8fb2e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -92,6 +92,15 @@ struct GemmParam { int ldb = 0; // B column strides }; +// Minimum number of elements for 16-byte alignment, given a data type. +// cuBLAS requires (dim * typeSize) % 16 == 0 for FP8 tensor core usage, +// i.e. dim % (128 / typeBits) == 0. +constexpr size_t kAlignmentBytes = 16; + +size_t min_alignment_elements(transformer_engine::DType dtype) { + return kAlignmentBytes * 8 / transformer_engine::typeToNumBits(dtype); +} + /* Populate parameters for cuBLAS GEMM * * cuBLAS follows the BLAS convention of column-major ordering. This @@ -157,9 +166,14 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } if (is_fp8_dtype(ret.Atype)) { - // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.lda % 16 == 0, - "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + const auto align = min_alignment_elements(ret.Atype); + NVTE_CHECK(ret.lda % align == 0, + "FP8 GEMM requires leading dimension of A to be divisible by ", align, + "," + " got lda=", + ret.lda, " (m=", m, ", n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -176,6 +190,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; + + const auto align = min_alignment_elements(ret.Atype); + NVTE_CHECK((ret.lda % align) == 0, + "NVFP4 GEMM requires leading dimension of A to be divisible by ", align, + "," + " got lda=", + ret.lda, " (m=", m, ", n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (mxfp8) { // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe. // Note: Row-wise and column-wise data are scaled along different @@ -191,6 +214,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = is_A_transposed ? 
A.data.dtype : A.columnwise_data.dtype; ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = is_A_transposed ? k : m; + + const auto align = min_alignment_elements(ret.Atype); + NVTE_CHECK((ret.lda % align) == 0, + "MXFP8 GEMM requires leading dimension of A to be divisible by ", align, + "," + " got lda=", + ret.lda, " (m=", m, ", n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. @@ -205,13 +237,22 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; - // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK((ret.lda % 16) == 0, - "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + const auto align = min_alignment_elements(ret.Atype); + NVTE_CHECK((ret.lda % align) == 0, + "Block-scaled FP8 GEMM requires leading dimension of A to be divisible by ", align, + "," + " got lda=", + ret.lda, " (m=", m, ", n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. NVTE_CHECK((m % 8) == 0, - "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + "Block-scaled FP8 GEMM requires m to be divisible by 8," + " got m=", + m, " (n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else { NVTE_ERROR("A has unsupported scaling mode"); } @@ -247,10 +288,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.ldb = is_B_transposed ? k : n; } - if (is_fp8_dtype(ret.Atype)) { - // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.ldb % 16 == 0, - "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + if (is_fp8_dtype(ret.Btype)) { + const auto align = min_alignment_elements(ret.Btype); + NVTE_CHECK(ret.ldb % align == 0, + "FP8 GEMM requires leading dimension of B to be divisible by ", align, + "," + " got ldb=", + ret.ldb, " (m=", m, ", n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else if (nvfp4) { if (is_B_transposed) { @@ -265,6 +311,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; + + const auto align = min_alignment_elements(ret.Btype); + NVTE_CHECK((ret.ldb % align) == 0, + "NVFP4 GEMM requires leading dimension of B to be divisible by ", align, + "," + " got ldb=", + ret.ldb, " (m=", m, ", n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (mxfp8) { if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); @@ -276,6 +331,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; ret.B_scale_inv = is_B_transposed ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; + + const auto align = min_alignment_elements(ret.Btype); + NVTE_CHECK((ret.ldb % align) == 0, + "MXFP8 GEMM requires leading dimension of B to be divisible by ", align, + "," + " got ldb=", + ret.ldb, " (m=", m, ", n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. @@ -290,14 +354,21 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; - // Requirements from - // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK((ret.ldb % 16) == 0, - "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + const auto align = min_alignment_elements(ret.Btype); + NVTE_CHECK((ret.ldb % align) == 0, + "Block-scaled FP8 GEMM requires leading dimension of B to be divisible by ", align, + "," + " got ldb=", + ret.ldb, " (m=", m, ", n=", n, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { - // Observed this requirement only present for B tensor is 1D quantized. NVTE_CHECK((n % 8) == 0, - "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + "Block-scaled FP8 GEMM requires n to be divisible by 8 for 1D block scaling," + " got n=", + n, " (m=", m, ", k=", k, + ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else { NVTE_ERROR("B has unsupported scaling mode"); @@ -765,10 +836,21 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, &heuristicResult, &returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, - "Unable to find suitable cuBLAS GEMM algorithm"); - NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); + if (status == CUBLAS_STATUS_NOT_SUPPORTED) { + returnedResults = 0; + } else { + NVTE_CHECK_CUBLAS(status); + } + NVTE_CHECK(returnedResults != 0, + "Unable to find suitable cuBLAS GEMM algorithm" + " (m=", + m, ", n=", n, ", k=", k, ", A.scaling_mode=", to_string(inputA->scaling_mode), + ", B.scaling_mode=", to_string(inputB->scaling_mode), + ")." + " This may be caused by unsupported tensor dimensions for the" + " current quantization recipe." + " Set CUBLASLT_LOG_LEVEL=5 to get the exact reason from cuBLAS." 
+ " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */ diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index 90bc3985a4..59bdeb0e01 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -141,7 +141,10 @@ void launch_kernel(const void* const in, void* const out, uint32_t data_rows, ui alignof(uint4), " bytes"); NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); - NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors"); + NVTE_CHECK(data_rows % 4 == 0, + "Block scaling swizzle requires data_rows to be divisible by 4," + " got data_rows=", + data_rows); const uint32_t tiles_x = DIVUP(data_cols, 128u); const uint32_t tiles_y = DIVUP(data_rows, 128u); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b59f3fa3c5..57cb416b3c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1346,9 +1346,6 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, - " (got shape=", shape, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1665,10 +1662,6 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); - NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, - " (got shape=", shape, ")"); - std::vector scale_shape; bool rowwise_usage = !columnwise; @@ -1676,11 +1669,11 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s if (rowwise_usage) { // rowwise scaling factor shape size_t sinv0 = roundup(numel / last_dim, 128); - size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); + size_t sinv1 = roundup(ceildiv(last_dim, MXFP8_BLOCK_SIZE), 4); scale_shape = {sinv0, sinv1}; } else { // columnwise scaling factor shape - size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + size_t sinv0 = roundup(ceildiv(numel / last_dim, MXFP8_BLOCK_SIZE), 4); size_t sinv1 = roundup(last_dim, 128); scale_shape = {sinv0, sinv1}; } @@ -1739,11 +1732,17 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve } } const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; - NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", - NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); - NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, - "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, - " (got shape=", shape, ")"); + NVTE_CHECK(flat_last_dim % 2 == 0, + "NVFP4 requires the last tensor dimension to be divisible by 2," + " got tensor with shape ", + shape, " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); + if (this->with_rht) { + NVTE_CHECK(flat_first_dim % 16 == 0, + "NVFP4 with random Hadamard transform requires the" + " product of all dimensions except the last to be divisible by 16," + " got tensor with shape ", + shape, " (flat_first_dim=", flat_first_dim, ")"); + } const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -2438,12 +2437,6 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); auto flat_first_dim = numel / last_dim; - NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", - NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")"); - NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, - "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, - " (got shape=", shape, ")"); - std::vector scale_shape; bool rowwise_usage = !columnwise; @@ -2451,12 +2444,12 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s if (rowwise_usage) { // rowwise scaling factor shape size_t sinv0 = roundup(flat_first_dim, 128); - size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); + size_t sinv1 = roundup(ceildiv(last_dim, NVFP4_BLOCK_SIZE), 4); scale_shape = {sinv0, sinv1}; } else { // columnwise scaling factor shape size_t sinv0 = roundup(last_dim, 128); - size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); + size_t sinv1 = roundup(ceildiv(flat_first_dim, NVFP4_BLOCK_SIZE), 4); scale_shape = {sinv0, sinv1}; } return scale_shape; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index a0d4ac3530..64586bf1d9 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1116,7 +1116,7 @@ def _start_all_gather_fp8_blockwise( raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") # Fall back to high-precision all-gather if FP8 is not supported - if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1: + if not quantizer.supports_quantized_allgather(inp) or quantizer.block_scaling_dim != 1: warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") if isinstance(inp, QuantizedTensorStorage): inp = inp.dequantize(dtype=dtype) # Dequantize if needed @@ -1368,7 +1368,7 @@ def _all_gather_nvfp4( if ( not isinstance(inp, NVFP4TensorStorage) and quantizer is not None - and not quantizer.is_quantizable(inp) + and not quantizer.supports_quantized_allgather(inp) ): warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") if isinstance(inp, QuantizedTensorStorage): @@ -1542,7 +1542,7 @@ def _all_gather_mxfp8( if ( not isinstance(inp, MXFP8TensorStorage) and quantizer is not None - and not quantizer.is_quantizable(inp) + and not quantizer.supports_quantized_allgather(inp) ): warnings.warn("Cannot quantize input tensor. 
Performing all-gather in high precision.") if isinstance(inp, QuantizedTensorStorage): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f26faade0a..f9515e5f66 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -30,7 +30,6 @@ ) from ..quantization import FP8GlobalStateManager from ..utils import ( - assert_dim_for_fp8_exec, cast_if_needed, clear_tensor_data, divide, @@ -162,8 +161,6 @@ def forward( assert inp_shape[-1] == in_features, "GEMM not possible" inp = inp.view((-1, in_features)) inputmat = inp - if fp8: - assert_dim_for_fp8_exec(inputmat, weight) # Cast for native AMP nvtx_range_push(f"{nvtx_label}.norm_input_cast") diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a8d6e2e609..923e9cb424 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -41,7 +41,6 @@ get_default_init_method, init_method_constant, cast_if_needed, - assert_dim_for_fp8_exec, clear_tensor_data, requires_grad, needs_quantized_gemm, @@ -344,8 +343,6 @@ def _forward( in_features, inp_shape = ln_weight.numel(), inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) - if fp8: - assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) activation_func = _act_func( activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 12339e7772..02f3b1431a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -35,7 +35,6 @@ init_method_constant, requires_grad, needs_quantized_gemm, - assert_dim_for_fp8_exec, nvtx_range_pop, nvtx_range_push, get_nvtx_range_context, @@ -176,7 +175,6 @@ def _linear_forward_impl( inputmat_total = None # Input tensor to pass to GEMM (gathered) own_quantized_input = False if fp8: - assert_dim_for_fp8_exec(inputmat, weight) if save_original_input: assert not isinstance( input_quantizer, Float8Quantizer diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..59689d73c9 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -363,11 +363,14 @@ def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" return False - def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument - """Whether tensor supports quantized all-gather - - Consider a less misleading function name. - + # pylint: disable-next=unused-argument + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: + """Whether tensor shape supports quantized all-gather. + + When False, the distributed all-gather falls back to gathering + in high precision and quantizing afterward. This is needed when + the local shard's shape would cause scaling factor blocks to + span across GPU boundaries. 
""" return True diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 914397b9b6..53a686e743 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -191,8 +191,8 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: colwise_shape.extend(shape[:-1]) return tuple(colwise_shape) - def is_quantizable(self, inp: torch.Tensor) -> bool: - """Returns whether or not given inp can be quantized""" + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: + """Whether tensor shape supports quantized all-gather.""" shape = inp.size() if len(shape) < 2: return False diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5cab519c79..2d6bf838e9 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -86,13 +86,18 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) - def is_quantizable(self, inp: torch.Tensor) -> bool: - """Returns whether or not given inp can be quantized""" + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: + """Whether tensor shape supports quantized all-gather. + + For distributed all-gather with columnwise scaling, the first + dimension must be aligned to the block size so that no scaling + block spans across GPU boundaries. + """ if inp.ndim < 2: return False - if inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0: + if inp.shape[-1] % 16 != 0: return False - if math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0: + if self.columnwise_usage and math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0: return False return True @@ -111,21 +116,19 @@ def make_empty( device = torch.device("cuda") assert ( - shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0 - and math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE == 0 - ), ( - f"Incorrect shape {shape} for MXFP8. Tensor dims must be divisible by" - f" {MXFP8_BLOCK_SCALING_SIZE}" - ) + shape[-1] % 16 == 0 + ), f"Incorrect shape {shape} for MXFP8. Last dimension must be divisible by 16." - # Allocate FP8 data + # Allocate FP8 data. Use ceil-division for the scaling dim so partial + # trailing blocks (allowed by the relaxed last_dim % 16 constraint) get + # one scale each, matching the C++ quantizer (DIVUP in quantizer.cpp). data = None scale_inv = None if self.rowwise_usage: data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) scale_inv = torch.empty( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), 4), dtype=torch.uint8, device=device, pin_memory=pin_memory, @@ -139,7 +142,9 @@ def make_empty( shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) columnwise_scale_inv = torch.empty( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple( + math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE), 4 + ), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, device=device, @@ -189,18 +194,23 @@ def get_scale_shape( Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. 
CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout """ + # Use ceil-division (mirroring C++ DIVUP in quantizer.cpp) so partial + # trailing blocks from the relaxed last_dim % 16 constraint get one + # scale each instead of collapsing to zero. if columnwise: - # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] + # Columnwise: scale_inv shape is [ceildiv(prod(shape[:-1]), BLOCK_SIZE), shape[-1]] # with padding to multiples of [4, 128] return ( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple( + math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE), 4 + ), round_up_to_nearest_multiple(shape[-1], 128), ) - # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE] + # Rowwise: scale_inv shape is [prod(shape[:-1]), ceildiv(shape[-1], BLOCK_SIZE)] # with padding to multiples of [128, 4] return ( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), 4), ) def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]: @@ -558,14 +568,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): shape = args[1] first_dim = math.prod(shape[:-1]) last_dim = shape[-1] - if ( - first_dim % MXFP8_BLOCK_SCALING_SIZE != 0 - or last_dim % MXFP8_BLOCK_SCALING_SIZE != 0 - ): + # Fall back to high-precision for shapes the quantizer cannot + # represent. Only last_dim % 16 is required (matches the relaxed + # C++ quantizer); partial trailing blocks use ceildiv. + if last_dim % 16 != 0: return super().__torch_dispatch__(func, types, args, kwargs) - rowwise_scale_inv_shape = [first_dim, last_dim // MXFP8_BLOCK_SCALING_SIZE] + rowwise_scale_inv_shape = [ + first_dim, + math.ceil(last_dim / MXFP8_BLOCK_SCALING_SIZE), + ] columnwise_scale_inv_shape = [ - first_dim // MXFP8_BLOCK_SCALING_SIZE, + math.ceil(first_dim / MXFP8_BLOCK_SCALING_SIZE), last_dim, ] if tensor._rowwise_data is not None: @@ -649,15 +662,22 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m shape = self.shape if self._with_gemm_swizzled_scales: raise NotImplementedError( - "FSDP2 is only supported for MXFP8Tensors with compact scales" + "FSDP2 is only supported for MXFP8Tensors with compact scales." ) + flattened_in_shape0 = math.prod(shape[:-1]) if rowwise_scale_inv is not None: # Remove padding from rowwise scale_inv - flattened_in_shape0 = math.prod(shape[:-1]) if rowwise_scale_inv.size(0) != flattened_in_shape0: rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0] if columnwise_scale_inv is not None: # Remove padding from columnwise scale_inv + if flattened_in_shape0 % MXFP8_BLOCK_SCALING_SIZE != 0: + raise NotImplementedError( + "FSDP2 during the backward pass is only supported for MXFP8Tensors with " + f"the flattened first dimension divisible by {MXFP8_BLOCK_SCALING_SIZE} " + f"and got tensor with shape {shape} with flattened first dimension " + f"{flattened_in_shape0}." 
+ ) flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE if columnwise_scale_inv.size(0) != flattened_in_shape0: columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0] diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index eb514d3a9e..557420bbd9 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -210,13 +210,18 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) - def is_quantizable(self, inp: torch.Tensor) -> bool: - """Returns whether or not given inp can be quantized""" + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: + """Whether tensor shape supports quantized all-gather. + + For distributed all-gather with columnwise scaling, the first + dimension must be aligned to the block size so that no scaling + block spans across GPU boundaries. + """ if inp.ndim < 2: return False - if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0: + if inp.shape[-1] % 32 != 0: return False - if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0: + if self.columnwise_usage and math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0: return False return True @@ -303,16 +308,9 @@ def make_empty( if device is None: device = torch.device("cuda") - assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, ( - f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" - f" {NVFP4_BLOCK_SCALING_SIZE}" - ) - - flat_first_dim = math.prod(shape[:-1]) - assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, ( - f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" - f" {NVFP4_BLOCK_SCALING_SIZE}" - ) + assert ( + shape[-1] % 32 == 0 + ), f"Incorrect shape {shape} for NVFP4. Last dimension must be divisible by 32." 
# Allocate FP4 data data = None @@ -339,7 +337,7 @@ def make_empty( if self.columnwise_usage: # enforce 2D shape to avoid [S, B, H] shape and B and be 1 # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - shape_2d = tuple([flat_first_dim, shape[-1]]) + shape_2d = (math.prod(shape[:-1]), shape[-1]) columnwise_data = torch.empty( self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)), dtype=torch.uint8, diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index a76f205acc..02237d8419 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -471,24 +471,6 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return tensor.to(dtype=dtype) -def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: - """Check if tensor dimensions are supported for FP8 TN GEMM""" - return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0 - - -def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: - """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.""" - - for tensor in tensors: - if math.prod(tensor.shape[:-1]) % 8 != 0 or tensor.shape[-1] % 16 != 0: - raise ValueError( - "FP8 execution requires the product of all dimensions except the last to be" - " divisible by 8 and the last dimension to be divisible by 16, but got tensor" - f" with dims={list(tensor.size())} (product of leading dims =" - f" {math.prod(tensor.shape[:-1])}, last dim = {tensor.shape[-1]})" - ) - - def is_bf16_compatible() -> bool: """Replaces torch.cuda.is_bf16_compatible() with an explicit check on device compute capability to enforce sm_80 or higher.
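
The alignment rule that min_alignment_elements encodes in cublaslt_gemm.cu reduces to plain arithmetic: a leading dimension is usable when dim * typeSize is a multiple of 16 bytes, i.e. dim % (128 / typeBits) == 0. Below is a minimal standalone Python sketch of that rule; the bit widths are the usual ones for these formats and are not read from the TE headers.

# Minimal sketch of the cuBLAS leading-dimension rule encoded by
# min_alignment_elements(): a dimension can be used as lda/ldb when
# dim * type_size is a multiple of 16 bytes, i.e. dim % (128 // type_bits) == 0.

ALIGNMENT_BITS = 16 * 8  # 16 bytes

def min_alignment_elements(type_bits: int) -> int:
    """Smallest element count whose total size is 16-byte aligned."""
    return ALIGNMENT_BITS // type_bits

for name, bits in [("fp8_e4m3", 8), ("fp4_e2m1", 4), ("bf16", 16)]:
    print(f"{name}: leading dim must be a multiple of {min_alignment_elements(bits)}")
# fp8_e4m3 -> 16, fp4_e2m1 -> 32, bf16 -> 8, which is why the tight-dims test
# uses K=16 for the plain FP8 recipes but K=32 for NVFP4.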
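
The partial-block fix in MXFP8Quantizer.get_scale_shape and make_empty comes down to ceil-division versus floor-division of the block count. A standalone sketch of the padded rowwise scale shape for the (32, 16) case that test_mxfp8_scale_shape_partial_block exercises; the helper names here are illustrative, not the TE ones.

import math

MXFP8_BLOCK = 32

def round_up(x: int, m: int) -> int:
    """Round x up to the nearest multiple of m."""
    return math.ceil(x / m) * m

def rowwise_scale_shape(shape, floor_div: bool):
    rows = math.prod(shape[:-1])
    # One scale per 32-element block along the last dim; padding to [128, 4].
    blocks = shape[-1] // MXFP8_BLOCK if floor_div else math.ceil(shape[-1] / MXFP8_BLOCK)
    return (round_up(rows, 128), round_up(blocks, 4))

# A (32, 16) tensor has a single partial 16-element block along the last dim.
print(rowwise_scale_shape((32, 16), floor_div=True))   # (128, 0): scale tensor collapses
print(rowwise_scale_shape((32, 16), floor_div=False))  # (128, 4): one scale, padded to 4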
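
The supports_quantized_allgather checks only require first-dim alignment when columnwise scaling is in use because that is the dimension the all-gather concatenates along. A small sketch of the shard-boundary condition behind the high-precision fallback; the block size matches MXFP8's 32-element blocks, and the function name is illustrative.

BLOCK = 32  # columnwise scaling block size along the gather dimension

def blocks_cross_shard_boundary(local_rows: int, world_size: int) -> bool:
    """True if a columnwise scaling block of the gathered tensor would straddle
    two ranks' shards, which per-rank quantization cannot represent."""
    # Shard boundaries of the gathered tensor sit at multiples of local_rows;
    # block boundaries sit at multiples of BLOCK.
    return any((r * local_rows) % BLOCK != 0 for r in range(1, world_size))

print(blocks_cross_shard_boundary(local_rows=24, world_size=2))  # True  -> gather in high precision
print(blocks_cross_shard_boundary(local_rows=64, world_size=2))  # False -> quantized all-gather is safe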
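
The NVFP4-with-RHT skip in test_linear_tight_dims rests on the Hadamard transform being orthogonal: in exact arithmetic grad^T H^T H x equals grad^T x, but once the transformed operands are quantized the cancellation only holds approximately. A numpy sketch of that argument, using coarse rounding as a stand-in for FP4 quantization; the Hadamard construction below is the standard Sylvester recursion, not TE's kernel.

import numpy as np

def hadamard(n: int) -> np.ndarray:
    """Normalized Hadamard matrix for n a power of two, so H @ H.T == I."""
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h / np.sqrt(n)

rng = np.random.default_rng(0)
H = hadamard(16)
g, x = rng.standard_normal(16), rng.standard_normal(16)

# Exact arithmetic: the transform cancels, (H g) . (H x) == g . x.
assert np.isclose((H @ g) @ (H @ x), g @ x)

# Coarse rounding after the transform breaks the exact cancellation, which is
# why the plain torch.nn.Linear baseline cannot reproduce the RHT wgrad path.
q = lambda v, step=0.25: np.round(v / step) * step
print((q(H @ g)) @ (q(H @ x)) - g @ x)  # small but nonzero residual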