From bd206afaa470eeb40d5e499adf843c0424793ace Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Apr 2026 15:58:37 -0700 Subject: [PATCH 01/18] Remove assert_dim_for_fp8_exec dimension guard This one-size-fits-all Python guard checked dimensions before the recipe was known, rejecting valid shapes. Dimension validation is handled per-recipe in the C++ quantizer where requirements are known. Signed-off-by: Przemek Tredak --- .../pytorch/module/layernorm_linear.py | 3 --- .../pytorch/module/layernorm_mlp.py | 3 --- transformer_engine/pytorch/module/linear.py | 2 -- transformer_engine/pytorch/utils.py | 17 ----------------- 4 files changed, 25 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index dc021ca6b7..e30871ecac 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -28,7 +28,6 @@ ) from ..quantization import FP8GlobalStateManager from ..utils import ( - assert_dim_for_fp8_exec, cast_if_needed, clear_tensor_data, divide, @@ -159,8 +158,6 @@ def forward( assert inp_shape[-1] == in_features, "GEMM not possible" inp = inp.view((-1, in_features)) inputmat = inp - if fp8: - assert_dim_for_fp8_exec(inputmat, weight) # Cast for native AMP nvtx_range_push(f"{nvtx_label}.norm_input_cast") diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a99de65c4a..c125683ce2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -39,7 +39,6 @@ get_default_init_method, init_method_constant, cast_if_needed, - assert_dim_for_fp8_exec, clear_tensor_data, requires_grad, needs_quantized_gemm, @@ -340,8 +339,6 @@ def _forward( in_features, inp_shape = ln_weight.numel(), inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) - if fp8: - assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) activation_func = _act_func( activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 8510f6cf8f..07bf60065f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -33,7 +33,6 @@ init_method_constant, requires_grad, needs_quantized_gemm, - assert_dim_for_fp8_exec, nvtx_range_pop, nvtx_range_push, get_nvtx_range_context, @@ -179,7 +178,6 @@ def forward( inputmat_total = None # Input tensor to pass to GEMM (gathered) own_quantized_input = False if fp8: - assert_dim_for_fp8_exec(inputmat, weight) if save_original_input: assert not isinstance( input_quantizer, Float8Quantizer diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index a76f205acc..ac64752ac7 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -471,23 +471,6 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return tensor.to(dtype=dtype) -def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: - """Check if tensor dimensions are supported for FP8 TN GEMM""" - return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0 - - -def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: - """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.""" - - for tensor in tensors: - if math.prod(tensor.shape[:-1]) % 8 != 0 or tensor.shape[-1] % 16 != 0: - raise ValueError( - "FP8 execution requires the product of all dimensions except the last to be" - " divisible by 8 and the last dimension to be divisible by 16, but got tensor" - f" with dims={list(tensor.size())} (product of leading dims =" - f" {math.prod(tensor.shape[:-1])}, last dim = {tensor.shape[-1]})" - ) - def is_bf16_compatible() -> bool: """Replaces torch.cuda.is_bf16_compatible() with an explicit From d2636a86db9aa994460bd9bc151d8a45727095c4 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Apr 2026 16:00:02 -0700 Subject: [PATCH 02/18] Relax dimension requirements for MXFP8, NVFP4, and block scaling - MXFP8: only require last dim divisible by 16 (was: both dims by 32). The kernel handles partial blocks via bounds checks and TMA zero-padding. Scale tensors are over-allocated with roundup alignment. - NVFP4: only require last dim divisible by 32 (was: both dims by 16). 4-bit data needs 32 elements for 16-byte alignment. - Float8BlockScaling swizzle: remove data_rows%4 assertion. The kernel already handles non-aligned rows via DIVUP and OOB zero-fill. - Fix integer truncation in MXFP8/NVFP4 get_scale_shape to use ceildiv instead of plain division, ensuring correct scale allocation for non-block-aligned dimensions. Signed-off-by: Przemek Tredak --- .../common/swizzle/swizzle_block_scaling.cu | 1 - transformer_engine/pytorch/csrc/quantizer.cpp | 38 +++++++++---------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index 90bc3985a4..467389c25e 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -141,7 +141,6 @@ void launch_kernel(const void* const in, void* const out, uint32_t data_rows, ui alignof(uint4), " bytes"); NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); - NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors"); const uint32_t tiles_x = DIVUP(data_cols, 128u); const uint32_t tiles_y = DIVUP(data_rows, 128u); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b59f3fa3c5..52bbe9a4fc 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1346,9 +1346,10 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, - " (got shape=", shape, ")"); + NVTE_CHECK(flat_last_dim % 16 == 0, + "MXFP8 requires the last tensor dimension to be divisible by 16," + " got tensor with shape ", shape, + " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1665,9 +1666,9 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); - NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, - " (got shape=", shape, ")"); + NVTE_CHECK(last_dim % 16 == 0, + "MXFP8 requires the last tensor dimension to be divisible by 16," + " got tensor with shape ", shape); std::vector scale_shape; @@ -1676,11 +1677,11 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s if (rowwise_usage) { // rowwise scaling factor shape size_t sinv0 = roundup(numel / last_dim, 128); - size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); + size_t sinv1 = roundup(ceildiv(last_dim, MXFP8_BLOCK_SIZE), 4); scale_shape = {sinv0, sinv1}; } else { // columnwise scaling factor shape - size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + size_t sinv0 = roundup(ceildiv(numel / last_dim, MXFP8_BLOCK_SIZE), 4); size_t sinv1 = roundup(last_dim, 128); scale_shape = {sinv0, sinv1}; } @@ -1739,11 +1740,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", - NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); - NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, - "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, - " (got shape=", shape, ")"); + NVTE_CHECK(flat_last_dim % 32 == 0, + "NVFP4 requires the last tensor dimension to be divisible by 32," + " got tensor with shape ", shape, + " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -2438,11 +2438,9 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); auto flat_first_dim = numel / last_dim; - NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", - NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")"); - NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, - "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, - " (got shape=", shape, ")"); + NVTE_CHECK(last_dim % 32 == 0, + "NVFP4 requires the last tensor dimension to be divisible by 32," + " got tensor with shape ", shape); std::vector scale_shape; @@ -2451,12 +2449,12 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s if (rowwise_usage) { // rowwise scaling factor shape size_t sinv0 = roundup(flat_first_dim, 128); - size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); + size_t sinv1 = roundup(ceildiv(last_dim, NVFP4_BLOCK_SIZE), 4); scale_shape = {sinv0, sinv1}; } else { // columnwise scaling factor shape size_t sinv0 = roundup(last_dim, 128); - size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); + size_t sinv1 = roundup(ceildiv(flat_first_dim, NVFP4_BLOCK_SIZE), 4); scale_shape = {sinv0, sinv1}; } return scale_shape; From 6c3a993f338ee7cd346cbb39d7a29f0b669d74eb Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Apr 2026 16:00:48 -0700 Subject: [PATCH 03/18] Update is_quantizable() and make_empty() to match relaxed requirements - MXFP8: only check last dim divisible by 16 (removed first-dim check) - NVFP4: only check last dim divisible by 32 (removed first-dim check) Signed-off-by: Przemek Tredak --- .../pytorch/tensor/mxfp8_tensor.py | 13 ++++--------- .../pytorch/tensor/nvfp4_tensor.py | 16 ++++------------ 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5cab519c79..5e1a60de90 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -90,9 +90,7 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" if inp.ndim < 2: return False - if inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0: - return False - if math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0: + if inp.shape[-1] % 16 != 0: return False return True @@ -110,12 +108,9 @@ def make_empty( if device is None: device = torch.device("cuda") - assert ( - shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0 - and math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE == 0 - ), ( - f"Incorrect shape {shape} for MXFP8. Tensor dims must be divisible by" - f" {MXFP8_BLOCK_SCALING_SIZE}" + assert shape[-1] % 16 == 0, ( + f"Incorrect shape {shape} for MXFP8." + f" Last dimension must be divisible by 16." ) # Allocate FP8 data diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index eb514d3a9e..95b5a997b8 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -214,9 +214,7 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" if inp.ndim < 2: return False - if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0: - return False - if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0: + if inp.shape[-1] % 32 != 0: return False return True @@ -303,15 +301,9 @@ def make_empty( if device is None: device = torch.device("cuda") - assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, ( - f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" - f" {NVFP4_BLOCK_SCALING_SIZE}" - ) - - flat_first_dim = math.prod(shape[:-1]) - assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, ( - f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" - f" {NVFP4_BLOCK_SCALING_SIZE}" + assert shape[-1] % 32 == 0, ( + f"Incorrect shape {shape} for NVFP4." + f" Last dimension must be divisible by 32." ) # Allocate FP4 data From 477d6a15c788f0da98a3ae11c86c467a493b7ed8 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Apr 2026 16:07:01 -0700 Subject: [PATCH 04/18] Add NVFP4 first-dim guard for columnwise usage (Hadamard requirement) NVFP4 backward pass uses the Hadamard transform which requires num_rows % 16 == 0. Add a clear error message at the quantizer level when columnwise_usage is enabled, instead of letting the user hit the raw Hadamard kernel assertion. Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 7 +++++++ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 9 ++++++++- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 9 ++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 52bbe9a4fc..5089427314 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1744,6 +1744,13 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve "NVFP4 requires the last tensor dimension to be divisible by 32," " got tensor with shape ", shape, " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); + if (this->with_rht) { + NVTE_CHECK(flat_first_dim % 16 == 0, + "NVFP4 with random Hadamard transform requires the" + " product of all dimensions except the last to be divisible by 16," + " got tensor with shape ", shape, + " (flat_first_dim=", flat_first_dim, ")"); + } const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5e1a60de90..5b6217a93e 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -87,11 +87,18 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: - """Returns whether or not given inp can be quantized""" + """Whether tensor can be quantized and distributed via all-gather. + + For distributed all-gather with columnwise scaling, the first + dimension must be aligned to the block size so that no scaling + block spans across GPU boundaries. + """ if inp.ndim < 2: return False if inp.shape[-1] % 16 != 0: return False + if self.columnwise_usage and math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0: + return False return True def make_empty( diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 95b5a997b8..05190e6e19 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -211,11 +211,18 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: - """Returns whether or not given inp can be quantized""" + """Whether tensor can be quantized and distributed via all-gather. + + For distributed all-gather with columnwise scaling, the first + dimension must be aligned to the block size so that no scaling + block spans across GPU boundaries. + """ if inp.ndim < 2: return False if inp.shape[-1] % 32 != 0: return False + if self.columnwise_usage and math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0: + return False return True def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: From 82857e25333a78ce057ea3660186013988c62ad3 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Apr 2026 16:25:59 -0700 Subject: [PATCH 05/18] Rename is_quantizable to supports_quantized_allgather MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old name was misleading — it doesn't check whether a tensor can be quantized in general, but whether a local shard's shape supports quantized all-gather without scaling factor blocks spanning across GPU boundaries. The new name reflects the actual purpose. Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/distributed.py | 6 +++--- transformer_engine/pytorch/quantized_tensor.py | 10 ++++++---- .../pytorch/tensor/float8_blockwise_tensor.py | 4 ++-- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4 ++-- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 4 ++-- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index b80e58fe20..23002c2398 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1119,7 +1119,7 @@ def _start_all_gather_fp8_blockwise( raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") # Fall back to high-precision all-gather if FP8 is not supported - if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1: + if not quantizer.supports_quantized_allgather(inp) or quantizer.block_scaling_dim != 1: warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") if isinstance(inp, QuantizedTensorStorage): inp = inp.dequantize(dtype=dtype) # Dequantize if needed @@ -1371,7 +1371,7 @@ def _all_gather_nvfp4( if ( not isinstance(inp, NVFP4TensorStorage) and quantizer is not None - and not quantizer.is_quantizable(inp) + and not quantizer.supports_quantized_allgather(inp) ): warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") if isinstance(inp, QuantizedTensorStorage): @@ -1545,7 +1545,7 @@ def _all_gather_mxfp8( if ( not isinstance(inp, MXFP8TensorStorage) and quantizer is not None - and not quantizer.is_quantizable(inp) + and not quantizer.supports_quantized_allgather(inp) ): warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") if isinstance(inp, QuantizedTensorStorage): diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..7c649f7ca0 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -363,11 +363,13 @@ def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" return False - def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument - """Whether tensor supports quantized all-gather - - Consider a less misleading function name. + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument + """Whether tensor shape supports quantized all-gather. + When False, the distributed all-gather falls back to gathering + in high precision and quantizing afterward. This is needed when + the local shard's shape would cause scaling factor blocks to + span across GPU boundaries. """ return True diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 914397b9b6..53a686e743 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -191,8 +191,8 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: colwise_shape.extend(shape[:-1]) return tuple(colwise_shape) - def is_quantizable(self, inp: torch.Tensor) -> bool: - """Returns whether or not given inp can be quantized""" + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: + """Whether tensor shape supports quantized all-gather.""" shape = inp.size() if len(shape) < 2: return False diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5b6217a93e..534fe04580 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -86,8 +86,8 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) - def is_quantizable(self, inp: torch.Tensor) -> bool: - """Whether tensor can be quantized and distributed via all-gather. + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: + """Whether tensor shape supports quantized all-gather. For distributed all-gather with columnwise scaling, the first dimension must be aligned to the block size so that no scaling diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 05190e6e19..cc85ba3f59 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -210,8 +210,8 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) - def is_quantizable(self, inp: torch.Tensor) -> bool: - """Whether tensor can be quantized and distributed via all-gather. + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: + """Whether tensor shape supports quantized all-gather. For distributed all-gather with columnwise scaling, the first dimension must be aligned to the block size so that no scaling From c0c62c8e68b0dd5115011ed7ee2828c03256620b Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Apr 2026 17:21:49 -0700 Subject: [PATCH 06/18] Improve FP8 GEMM dimension check error messages Include actual dimension values (m, n, k, lda, ldb) in error messages so users can trace which tensor dimension is misaligned. Reference the cuBLAS documentation for FP8 alignment requirements. Signed-off-by: Przemek Tredak --- .../common/gemm/cublaslt_gemm.cu | 72 ++++++++++++++----- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 144aea1a07..8d0a1be37e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -157,9 +157,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } if (is_fp8_dtype(ret.Atype)) { - // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage NVTE_CHECK(ret.lda % 16 == 0, - "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + "FP8 GEMM requires leading dimension of A to be divisible by 16," + " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." + " Ensure all tensor dimensions meet FP8 alignment requirements" + " (see https://docs.nvidia.com/cuda/cublas/#tensor-core-usage)."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -176,6 +178,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; + + // NVFP4 is 4-bit, so 32 elements = 16 bytes for alignment. + NVTE_CHECK((ret.lda % 32) == 0, + "NVFP4 GEMM requires leading dimension of A to be divisible by 32," + " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (mxfp8) { // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe. // Note: Row-wise and column-wise data are scaled along different @@ -191,6 +199,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = is_A_transposed ? k : m; + + NVTE_CHECK((ret.lda % 16) == 0, + "MXFP8 GEMM requires leading dimension of A to be divisible by 16," + " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. @@ -205,13 +218,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; - // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage NVTE_CHECK((ret.lda % 16) == 0, - "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + "Block-scaled FP8 GEMM requires leading dimension of A to be divisible by 16," + " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. NVTE_CHECK((m % 8) == 0, - "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + "Block-scaled FP8 GEMM requires m to be divisible by 8," + " got m=", m, " (n=", n, ", k=", k, ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else { NVTE_ERROR("A has unsupported scaling mode"); } @@ -248,9 +264,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } if (is_fp8_dtype(ret.Atype)) { - // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage NVTE_CHECK(ret.ldb % 16 == 0, - "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + "FP8 GEMM requires leading dimension of B to be divisible by 16," + " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." + " Ensure all tensor dimensions meet FP8 alignment requirements" + " (see https://docs.nvidia.com/cuda/cublas/#tensor-core-usage)."); } } else if (nvfp4) { if (is_B_transposed) { @@ -265,6 +283,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; + + // NVFP4 is 4-bit, so 32 elements = 16 bytes for alignment. + NVTE_CHECK((ret.ldb % 32) == 0, + "NVFP4 GEMM requires leading dimension of B to be divisible by 32," + " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (mxfp8) { if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); @@ -276,6 +300,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; + + NVTE_CHECK((ret.ldb % 16) == 0, + "MXFP8 GEMM requires leading dimension of B to be divisible by 16," + " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. @@ -290,14 +319,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; - // Requirements from - // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage NVTE_CHECK((ret.ldb % 16) == 0, - "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + "Block-scaled FP8 GEMM requires leading dimension of B to be divisible by 16," + " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { - // Observed this requirement only present for B tensor is 1D quantized. NVTE_CHECK((n % 8) == 0, - "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + "Block-scaled FP8 GEMM requires n to be divisible by 8 for 1D block scaling," + " got n=", n, " (m=", m, ", k=", k, ")." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else { NVTE_ERROR("B has unsupported scaling mode"); @@ -765,10 +795,20 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, &heuristicResult, &returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, - "Unable to find suitable cuBLAS GEMM algorithm"); - NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); + if (status == CUBLAS_STATUS_NOT_SUPPORTED) { + returnedResults = 0; + } else { + NVTE_CHECK_CUBLAS(status); + } + NVTE_CHECK(returnedResults != 0, + "Unable to find suitable cuBLAS GEMM algorithm" + " (m=", m, ", n=", n, ", k=", k, + ", A.scaling_mode=", to_string(inputA->scaling_mode), + ", B.scaling_mode=", to_string(inputB->scaling_mode), ")." + " This may be caused by unsupported tensor dimensions for the" + " current quantization recipe." + " Set CUBLASLT_LOG_LEVEL=5 to get the exact reason from cuBLAS." + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */ From 326563b5902d09e6c61392252fe49cac2c6d0e29 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Apr 2026 17:49:15 -0700 Subject: [PATCH 07/18] fixup: restore swizzle data_rows%4 check with improved message Signed-off-by: Przemek Tredak --- transformer_engine/common/swizzle/swizzle_block_scaling.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index 467389c25e..cbdbb232da 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -141,6 +141,9 @@ void launch_kernel(const void* const in, void* const out, uint32_t data_rows, ui alignof(uint4), " bytes"); NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + NVTE_CHECK(data_rows % 4 == 0, + "Block scaling swizzle requires data_rows to be divisible by 4," + " got data_rows=", data_rows); const uint32_t tiles_x = DIVUP(data_cols, 128u); const uint32_t tiles_y = DIVUP(data_rows, 128u); From 09aa7ca52b4490a9e54e7b95a3cb4e6b92f98e36 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 15 Apr 2026 11:44:07 -0700 Subject: [PATCH 08/18] Unify GEMM error message format Remove "Ensure all tensor dimensions..." sentence and parentheses from cuBLAS documentation links for consistent error messages. Signed-off-by: Przemek Tredak --- transformer_engine/common/gemm/cublaslt_gemm.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 8d0a1be37e..cb0550ca24 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -160,8 +160,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(ret.lda % 16 == 0, "FP8 GEMM requires leading dimension of A to be divisible by 16," " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." - " Ensure all tensor dimensions meet FP8 alignment requirements" - " (see https://docs.nvidia.com/cuda/cublas/#tensor-core-usage)."); + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -267,8 +266,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(ret.ldb % 16 == 0, "FP8 GEMM requires leading dimension of B to be divisible by 16," " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." - " Ensure all tensor dimensions meet FP8 alignment requirements" - " (see https://docs.nvidia.com/cuda/cublas/#tensor-core-usage)."); + " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else if (nvfp4) { if (is_B_transposed) { From c219ad0a66f5c5c750e77a422450e30eaa862dcc Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 15 Apr 2026 11:52:58 -0700 Subject: [PATCH 09/18] Use min_alignment_elements() for GEMM dimension checks Replace hardcoded alignment values (16, 32) with a helper that computes the minimum element alignment from the data type's bit width and the 16-byte cuBLAS alignment requirement. This ensures alignment checks stay correct if new data types are added. Signed-off-by: Przemek Tredak --- .../common/gemm/cublaslt_gemm.cu | 53 ++++++++++++------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index cb0550ca24..70a5d572fa 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -92,6 +92,15 @@ struct GemmParam { int ldb = 0; // B column strides }; +// Minimum number of elements for 16-byte alignment, given a data type. +// cuBLAS requires (dim * typeSize) % 16 == 0 for FP8 tensor core usage, +// i.e. dim % (128 / typeBits) == 0. +constexpr size_t kAlignmentBytes = 16; + +size_t min_alignment_elements(transformer_engine::DType dtype) { + return kAlignmentBytes * 8 / transformer_engine::typeToNumBits(dtype); +} + /* Populate parameters for cuBLAS GEMM * * cuBLAS follows the BLAS convention of column-major ordering. This @@ -157,8 +166,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } if (is_fp8_dtype(ret.Atype)) { - NVTE_CHECK(ret.lda % 16 == 0, - "FP8 GEMM requires leading dimension of A to be divisible by 16," + const auto align = min_alignment_elements(ret.Atype); + NVTE_CHECK(ret.lda % align == 0, + "FP8 GEMM requires leading dimension of A to be divisible by ", align, "," " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } @@ -178,9 +188,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; - // NVFP4 is 4-bit, so 32 elements = 16 bytes for alignment. - NVTE_CHECK((ret.lda % 32) == 0, - "NVFP4 GEMM requires leading dimension of A to be divisible by 32," + const auto align = min_alignment_elements(ret.Atype); + NVTE_CHECK((ret.lda % align) == 0, + "NVFP4 GEMM requires leading dimension of A to be divisible by ", align, "," " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (mxfp8) { @@ -199,8 +209,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - NVTE_CHECK((ret.lda % 16) == 0, - "MXFP8 GEMM requires leading dimension of A to be divisible by 16," + const auto align = min_alignment_elements(ret.Atype); + NVTE_CHECK((ret.lda % align) == 0, + "MXFP8 GEMM requires leading dimension of A to be divisible by ", align, "," " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { @@ -217,8 +228,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; - NVTE_CHECK((ret.lda % 16) == 0, - "Block-scaled FP8 GEMM requires leading dimension of A to be divisible by 16," + const auto align = min_alignment_elements(ret.Atype); + NVTE_CHECK((ret.lda % align) == 0, + "Block-scaled FP8 GEMM requires leading dimension of A to be divisible by ", align, "," " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. @@ -262,9 +274,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.ldb = is_B_transposed ? k : n; } - if (is_fp8_dtype(ret.Atype)) { - NVTE_CHECK(ret.ldb % 16 == 0, - "FP8 GEMM requires leading dimension of B to be divisible by 16," + if (is_fp8_dtype(ret.Btype)) { + const auto align = min_alignment_elements(ret.Btype); + NVTE_CHECK(ret.ldb % align == 0, + "FP8 GEMM requires leading dimension of B to be divisible by ", align, "," " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } @@ -282,9 +295,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; - // NVFP4 is 4-bit, so 32 elements = 16 bytes for alignment. - NVTE_CHECK((ret.ldb % 32) == 0, - "NVFP4 GEMM requires leading dimension of B to be divisible by 32," + const auto align = min_alignment_elements(ret.Btype); + NVTE_CHECK((ret.ldb % align) == 0, + "NVFP4 GEMM requires leading dimension of B to be divisible by ", align, "," " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (mxfp8) { @@ -299,8 +312,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - NVTE_CHECK((ret.ldb % 16) == 0, - "MXFP8 GEMM requires leading dimension of B to be divisible by 16," + const auto align = min_alignment_elements(ret.Btype); + NVTE_CHECK((ret.ldb % align) == 0, + "MXFP8 GEMM requires leading dimension of B to be divisible by ", align, "," " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { @@ -317,8 +331,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; - NVTE_CHECK((ret.ldb % 16) == 0, - "Block-scaled FP8 GEMM requires leading dimension of B to be divisible by 16," + const auto align = min_alignment_elements(ret.Btype); + NVTE_CHECK((ret.ldb % align) == 0, + "Block-scaled FP8 GEMM requires leading dimension of B to be divisible by ", align, "," " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { From 5162a6532e7921fed8e6c640ed863371247000d9 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 15 Apr 2026 12:01:24 -0700 Subject: [PATCH 10/18] Add test for te.Linear with small batch dimension (M=1) Verifies that the relaxed dimension checks allow small M values for recipes that support them on the forward pass. Signed-off-by: Przemek Tredak --- tests/pytorch/test_numerics.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 4bfe06095b..1ca10c9946 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1271,6 +1271,39 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_linear_small_M(recipe): + """Test te.Linear with small batch dimension (M=1). + + Verifies that the relaxed dimension checks allow small M values + for recipes that support them. Previously assert_dim_for_fp8_exec + would reject any M not divisible by 8. + """ + from transformer_engine.common.recipe import Float8BlockScaling, NVFP4BlockScaling + if isinstance(recipe, (Float8BlockScaling, NVFP4BlockScaling)): + pytest.skip( + f"{recipe.__class__.__name__} does not support M=1" + " (block scaling swizzle / Hadamard transform requirements)" + ) + + hidden_size = 128 + te_linear = Linear( + hidden_size, + 4 * hidden_size, + bias=False, + params_dtype=torch.bfloat16, + device="cuda", + ) + + x = torch.randn(1, 1, hidden_size, dtype=torch.bfloat16, device="cuda") + + with autocast(enabled=True, recipe=recipe): + with torch.no_grad(): + out = te_linear(x) + torch.cuda.synchronize() + assert out.shape == (1, 1, 4 * hidden_size) + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) From a969fc7121ef03468c4a69168618cf6696c4d718 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 16 Apr 2026 09:08:26 -0700 Subject: [PATCH 11/18] Add test for te.Linear with tight per-recipe dimensions Each recipe gets its tightest viable (M, N, K) for both inference and training. The baseline pre-quantizes+dequantizes inputs and weights via the recipe's quantizers so the comparison holds at BF16 tolerance. Signed-off-by: Przemek Tredak --- tests/pytorch/test_numerics.py | 158 +++++++++++++++++++++++++++------ 1 file changed, 133 insertions(+), 25 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 1ca10c9946..109541a1cc 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -14,6 +14,7 @@ from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, + RecipeState, get_align_size_for_quantization, ) from transformer_engine.pytorch.utils import ( @@ -1271,37 +1272,144 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) -@pytest.mark.parametrize("recipe", fp8_recipes) -def test_linear_small_M(recipe): - """Test te.Linear with small batch dimension (M=1). +def _tight_linear_dims(recipe_obj) -> Tuple[int, int, int, int, int, int]: + """Return tightest (M, N, K) for inference and training for a recipe. - Verifies that the relaxed dimension checks allow small M values - for recipes that support them. Previously assert_dim_for_fp8_exec - would reject any M not divisible by 8. + Constraints: + - cuBLAS needs 16B-aligned leading dims: 16 elts for FP8, 32 for FP4. + - Linear(K, N) fprop has lda=ldb=K, ldc=N (no M alignment needed). + - Training adds wgrad/dgrad with lda/ldb/ldc spanning M, N, K -> all must be aligned. + - Float8BlockScaling swizzle requires data_rows (first dim) % 4. + - NVFP4 RHT requires input first dim % 16 (subsumed by 32-elt alignment for training). """ - from transformer_engine.common.recipe import Float8BlockScaling, NVFP4BlockScaling - if isinstance(recipe, (Float8BlockScaling, NVFP4BlockScaling)): - pytest.skip( - f"{recipe.__class__.__name__} does not support M=1" - " (block scaling swizzle / Hadamard transform requirements)" - ) + if recipe_obj.delayed() or recipe_obj.float8_current_scaling(): + return (1, 16, 16, 16, 16, 16) + if recipe_obj.mxfp8(): + # MXFP8 needs each scaling dim >= 32-element block size; K for fprop, + # and M/N as well for wgrad/dgrad in training. + return (1, 16, 32, 32, 32, 32) + if recipe_obj.float8_block_scaling(): + # 1D block-scaled FP8 GEMM (native Hopper path) requires batch dim %8. + return (8, 128, 128, 16, 128, 128) + if recipe_obj.nvfp4(): + return (16, 32, 32, 32, 32, 32) + raise ValueError(f"Unknown recipe: {recipe_obj.__class__.__name__}") - hidden_size = 128 - te_linear = Linear( - hidden_size, - 4 * hidden_size, - bias=False, - params_dtype=torch.bfloat16, - device="cuda", - ) - x = torch.randn(1, 1, hidden_size, dtype=torch.bfloat16, device="cuda") +@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("inference", [True, False], ids=["inference", "training"]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=["bf16", "fp32"]) +def test_linear_tight_dims(recipe, inference, dtype): + """te.Linear with the tightest M/N/K per recipe, vs a pytorch baseline. + + Previously the Python assert_dim_for_fp8_exec rejected any M not divisible + by 8. With that guard removed, the C++ quantizers, swizzle kernel, and + cuBLAS are the real source of dimension constraints — they are looser and + recipe-specific. This test exercises the tightest shape each recipe should + accept and compares against a high-precision baseline. + + Both sides see the same dequantized input/weight, so the only difference is + FP8/FP4 tensor-core vs bf16/fp32 tensor-core accumulation roundoff. + """ + if recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip(f"{dtype} not supported for this NVFP4 recipe") + if ( + not inference + and recipe.fp4_quant_bwd_grad.random_hadamard_transform + ): + # The reference baseline cannot express the Hadamard transform that + # TE's wgrad GEMM applies to grad_output and input; their product is + # grad^T · H^T · H · x which is only ≈ grad^T · x after quantization + # noise averages out. + pytest.skip("NVFP4 with RHT: backward comparison needs Hadamard-aware baseline") + + # Seed deterministically so outputs don't depend on pytest parametrization order. + torch.manual_seed(0) - with autocast(enabled=True, recipe=recipe): + m_inf, n_inf, k_inf, m_tr, n_tr, k_tr = _tight_linear_dims(recipe) + M, N, K = (m_inf, n_inf, k_inf) if inference else (m_tr, n_tr, k_tr) + + device = "cuda" + + # For the fp32 path, the baseline must stay in full FP32 — the default TF32 + # tensor cores would drop 13 mantissa bits and diverge far beyond the FP32 + # tolerance we want to check. + prev_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + prev_tf32_cudnn = torch.backends.cudnn.allow_tf32 + if dtype == torch.float32: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + try: + te_linear = Linear(K, N, bias=False, params_dtype=dtype, device=device) + torch_linear = torch.nn.Linear(K, N, bias=False, device=device, dtype=dtype) + + # Forward-path quantizers: [input, weight, output]. + fwd_state = RecipeState.create(recipe, mode="forward", num_quantizers=3) + input_quantizer, weight_quantizer, _ = fwd_state.make_quantizers() + + # Share weights: TE gets the raw weight (it quantizes internally); the + # baseline gets dequantize(quantize(W)) so both do the same matmul. + W = torch.randn(N, K, dtype=dtype, device=device) with torch.no_grad(): - out = te_linear(x) - torch.cuda.synchronize() - assert out.shape == (1, 1, 4 * hidden_size) + te_linear.weight.copy_(W) + torch_linear.weight.copy_(weight_quantizer(W).dequantize()) + + # Input: feed x_qdq to BOTH sides. TE quantizes internally (near-idempotent + # on an already-quantized tensor). Feeding raw x only to TE would let + # current-scaling recipes recompute a slightly different scale from the + # already-quantized amax and diverge from the baseline. + x_raw = torch.randn(M, K, dtype=dtype, device=device) + x_qdq = input_quantizer(x_raw).dequantize() + + requires_grad = not inference + x_te = x_qdq.clone().detach().requires_grad_(requires_grad) + x_ref = x_qdq.clone().detach().requires_grad_(requires_grad) + + if inference: + with autocast(enabled=True, recipe=recipe), torch.no_grad(): + te_out = te_linear(x_te) + else: + with autocast(enabled=True, recipe=recipe): + te_out = te_linear(x_te) + ref_out = torch_linear(x_ref) + torch.cuda.synchronize() + + # Tolerance sized for tensor-core accumulation differences between the + # FP8/FP4 path (TE) and the high-precision reference. Both sides + # multiply the same dequantized values, but the FP8 tensor-core tiles + # reduce in a different order than the bf16/fp32 tensor core. Empirical + # values measured across Hopper (H100) and Blackwell (B200): + # - BF16 output: ~1 BF16 ULP per element (up to ~4 on cancelling sums). + # - FP32 output w/ full FP32 accumulator: ~1e-3 absolute on Hopper + # (fast-accumulator FP8 GEMM), ~1e-5 on Blackwell. + # - FP32 output, Float8BlockScaling on Hopper: effectively BF16 + # precision — the native 1D block-scaled FP8 GEMM uses a + # lower-precision block accumulator. + if dtype == torch.bfloat16: + tols = dict(rtol=1.6e-2, atol=3e-2) + elif recipe.float8_block_scaling(): + tols = dict(rtol=1.6e-2, atol=3e-2) + else: + tols = dict(rtol=1e-3, atol=5e-3) + torch.testing.assert_close(te_out, ref_out, **tols) + + if not inference: + # Quantize+dequantize grad_output so both paths see the same signal. + bwd_state = RecipeState.create(recipe, mode="backward", num_quantizers=2) + grad_output_quantizer = bwd_state.make_quantizers()[0] + grad_raw = torch.randn_like(te_out) + grad_qdq = grad_output_quantizer(grad_raw).dequantize() + + te_out.backward(grad_qdq) + ref_out.backward(grad_qdq) + torch.cuda.synchronize() + + torch.testing.assert_close(x_te.grad, x_ref.grad, **tols) + torch.testing.assert_close(te_linear.weight.grad, torch_linear.weight.grad, **tols) + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32_matmul + torch.backends.cudnn.allow_tf32 = prev_tf32_cudnn @pytest.mark.parametrize("dtype", param_types) From 356d5498127489146246793605a7a84930182fcd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 20:27:46 +0000 Subject: [PATCH 12/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_numerics.py | 5 +- .../common/gemm/cublaslt_gemm.cu | 71 +++++++++++++------ .../common/swizzle/swizzle_block_scaling.cu | 3 +- transformer_engine/pytorch/csrc/quantizer.cpp | 18 ++--- .../pytorch/quantized_tensor.py | 4 +- .../pytorch/tensor/mxfp8_tensor.py | 7 +- .../pytorch/tensor/nvfp4_tensor.py | 7 +- transformer_engine/pytorch/utils.py | 1 - 8 files changed, 72 insertions(+), 44 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 109541a1cc..9e04c37d28 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1314,10 +1314,7 @@ def test_linear_tight_dims(recipe, inference, dtype): if recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip(f"{dtype} not supported for this NVFP4 recipe") - if ( - not inference - and recipe.fp4_quant_bwd_grad.random_hadamard_transform - ): + if not inference and recipe.fp4_quant_bwd_grad.random_hadamard_transform: # The reference baseline cannot express the Hadamard transform that # TE's wgrad GEMM applies to grad_output and input; their product is # grad^T · H^T · H · x which is only ≈ grad^T · x after quantization diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 70a5d572fa..7105c8fb2e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -168,8 +168,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { const auto align = min_alignment_elements(ret.Atype); NVTE_CHECK(ret.lda % align == 0, - "FP8 GEMM requires leading dimension of A to be divisible by ", align, "," - " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." + "FP8 GEMM requires leading dimension of A to be divisible by ", align, + "," + " got lda=", + ret.lda, " (m=", m, ", n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else if (nvfp4) { @@ -190,8 +193,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const auto align = min_alignment_elements(ret.Atype); NVTE_CHECK((ret.lda % align) == 0, - "NVFP4 GEMM requires leading dimension of A to be divisible by ", align, "," - " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." + "NVFP4 GEMM requires leading dimension of A to be divisible by ", align, + "," + " got lda=", + ret.lda, " (m=", m, ", n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (mxfp8) { // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe. @@ -211,8 +217,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const auto align = min_alignment_elements(ret.Atype); NVTE_CHECK((ret.lda % align) == 0, - "MXFP8 GEMM requires leading dimension of A to be divisible by ", align, "," - " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." + "MXFP8 GEMM requires leading dimension of A to be divisible by ", align, + "," + " got lda=", + ret.lda, " (m=", m, ", n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling @@ -230,14 +239,19 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const auto align = min_alignment_elements(ret.Atype); NVTE_CHECK((ret.lda % align) == 0, - "Block-scaled FP8 GEMM requires leading dimension of A to be divisible by ", align, "," - " got lda=", ret.lda, " (m=", m, ", n=", n, ", k=", k, ")." + "Block-scaled FP8 GEMM requires leading dimension of A to be divisible by ", align, + "," + " got lda=", + ret.lda, " (m=", m, ", n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. NVTE_CHECK((m % 8) == 0, "Block-scaled FP8 GEMM requires m to be divisible by 8," - " got m=", m, " (n=", n, ", k=", k, ")." + " got m=", + m, " (n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else { NVTE_ERROR("A has unsupported scaling mode"); @@ -277,8 +291,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Btype)) { const auto align = min_alignment_elements(ret.Btype); NVTE_CHECK(ret.ldb % align == 0, - "FP8 GEMM requires leading dimension of B to be divisible by ", align, "," - " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." + "FP8 GEMM requires leading dimension of B to be divisible by ", align, + "," + " got ldb=", + ret.ldb, " (m=", m, ", n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else if (nvfp4) { @@ -297,8 +314,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const auto align = min_alignment_elements(ret.Btype); NVTE_CHECK((ret.ldb % align) == 0, - "NVFP4 GEMM requires leading dimension of B to be divisible by ", align, "," - " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." + "NVFP4 GEMM requires leading dimension of B to be divisible by ", align, + "," + " got ldb=", + ret.ldb, " (m=", m, ", n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (mxfp8) { if (is_B_transposed) { @@ -314,8 +334,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const auto align = min_alignment_elements(ret.Btype); NVTE_CHECK((ret.ldb % align) == 0, - "MXFP8 GEMM requires leading dimension of B to be divisible by ", align, "," - " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." + "MXFP8 GEMM requires leading dimension of B to be divisible by ", align, + "," + " got ldb=", + ret.ldb, " (m=", m, ", n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling @@ -333,13 +356,18 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const auto align = min_alignment_elements(ret.Btype); NVTE_CHECK((ret.ldb % align) == 0, - "Block-scaled FP8 GEMM requires leading dimension of B to be divisible by ", align, "," - " got ldb=", ret.ldb, " (m=", m, ", n=", n, ", k=", k, ")." + "Block-scaled FP8 GEMM requires leading dimension of B to be divisible by ", align, + "," + " got ldb=", + ret.ldb, " (m=", m, ", n=", n, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { NVTE_CHECK((n % 8) == 0, "Block-scaled FP8 GEMM requires n to be divisible by 8 for 1D block scaling," - " got n=", n, " (m=", m, ", k=", k, ")." + " got n=", + n, " (m=", m, ", k=", k, + ")." " See https://docs.nvidia.com/cuda/cublas/#tensor-core-usage."); } } else { @@ -815,9 +843,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } NVTE_CHECK(returnedResults != 0, "Unable to find suitable cuBLAS GEMM algorithm" - " (m=", m, ", n=", n, ", k=", k, - ", A.scaling_mode=", to_string(inputA->scaling_mode), - ", B.scaling_mode=", to_string(inputB->scaling_mode), ")." + " (m=", + m, ", n=", n, ", k=", k, ", A.scaling_mode=", to_string(inputA->scaling_mode), + ", B.scaling_mode=", to_string(inputB->scaling_mode), + ")." " This may be caused by unsupported tensor dimensions for the" " current quantization recipe." " Set CUBLASLT_LOG_LEVEL=5 to get the exact reason from cuBLAS." diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index cbdbb232da..59bdeb0e01 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -143,7 +143,8 @@ void launch_kernel(const void* const in, void* const out, uint32_t data_rows, ui "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); NVTE_CHECK(data_rows % 4 == 0, "Block scaling swizzle requires data_rows to be divisible by 4," - " got data_rows=", data_rows); + " got data_rows=", + data_rows); const uint32_t tiles_x = DIVUP(data_cols, 128u); const uint32_t tiles_y = DIVUP(data_rows, 128u); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 5089427314..4f88c19734 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1348,8 +1348,8 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; NVTE_CHECK(flat_last_dim % 16 == 0, "MXFP8 requires the last tensor dimension to be divisible by 16," - " got tensor with shape ", shape, - " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); + " got tensor with shape ", + shape, " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1668,7 +1668,8 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s NVTE_CHECK(last_dim % 16 == 0, "MXFP8 requires the last tensor dimension to be divisible by 16," - " got tensor with shape ", shape); + " got tensor with shape ", + shape); std::vector scale_shape; @@ -1742,14 +1743,14 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; NVTE_CHECK(flat_last_dim % 32 == 0, "NVFP4 requires the last tensor dimension to be divisible by 32," - " got tensor with shape ", shape, - " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); + " got tensor with shape ", + shape, " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); if (this->with_rht) { NVTE_CHECK(flat_first_dim % 16 == 0, "NVFP4 with random Hadamard transform requires the" " product of all dimensions except the last to be divisible by 16," - " got tensor with shape ", shape, - " (flat_first_dim=", flat_first_dim, ")"); + " got tensor with shape ", + shape, " (flat_first_dim=", flat_first_dim, ")"); } const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -2447,7 +2448,8 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s NVTE_CHECK(last_dim % 32 == 0, "NVFP4 requires the last tensor dimension to be divisible by 32," - " got tensor with shape ", shape); + " got tensor with shape ", + shape); std::vector scale_shape; diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 7c649f7ca0..fc50959385 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -363,7 +363,9 @@ def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" return False - def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument + def supports_quantized_allgather( + self, inp: torch.Tensor + ) -> bool: # pylint: disable=unused-argument """Whether tensor shape supports quantized all-gather. When False, the distributed all-gather falls back to gathering diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 534fe04580..668ec6f867 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -115,10 +115,9 @@ def make_empty( if device is None: device = torch.device("cuda") - assert shape[-1] % 16 == 0, ( - f"Incorrect shape {shape} for MXFP8." - f" Last dimension must be divisible by 16." - ) + assert ( + shape[-1] % 16 == 0 + ), f"Incorrect shape {shape} for MXFP8. Last dimension must be divisible by 16." # Allocate FP8 data data = None diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index cc85ba3f59..1db7fcceb2 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -308,10 +308,9 @@ def make_empty( if device is None: device = torch.device("cuda") - assert shape[-1] % 32 == 0, ( - f"Incorrect shape {shape} for NVFP4." - f" Last dimension must be divisible by 32." - ) + assert ( + shape[-1] % 32 == 0 + ), f"Incorrect shape {shape} for NVFP4. Last dimension must be divisible by 32." # Allocate FP4 data data = None diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index ac64752ac7..02237d8419 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -471,7 +471,6 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return tensor.to(dtype=dtype) - def is_bf16_compatible() -> bool: """Replaces torch.cuda.is_bf16_compatible() with an explicit check on device compute capability to enforce sm_80 or higher. From d8dbd916b2baa32aa68095057c4741975a4ca3f8 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 16 Apr 2026 15:04:52 -0700 Subject: [PATCH 13/18] Use ceildiv for MXFP8 Python scale shapes The C++ quantizer uses DIVUP for partial trailing blocks, so a last_dim of 16 (now allowed by the relaxed C++ guard) produces one scale. The Python helpers were still floor-dividing, so the corresponding scale tensors collapsed to zero size. Mirrors the C++ behavior across make_empty, get_scale_shape, and the new_zeros torch_dispatch path; the new_zeros path also relaxes its fall-back guard so shapes the quantizer now accepts aren't silently downgraded to a generic tensor. Adds a Python-level test covering last_dim=16 since the end-to-end GEMM test can't exercise it (the MXFP8 GEMM kernel still requires K >= 32). Signed-off-by: Przemek Tredak --- tests/pytorch/test_numerics.py | 25 ++++++++++++ .../pytorch/tensor/mxfp8_tensor.py | 40 +++++++++++++------ 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 9e04c37d28..30c360d6ee 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1409,6 +1409,31 @@ def test_linear_tight_dims(recipe, inference, dtype): torch.backends.cudnn.allow_tf32 = prev_tf32_cudnn +def test_mxfp8_scale_shape_partial_block(): + """MXFP8 Python tensor helpers must ceildiv the scale dim, not floor-div. + + The C++ quantizer uses DIVUP so a 16-element trailing partial block gets + one scale; if the Python side floor-divides instead, scale tensors + collapse to zero size. This test exercises the newly-allowed last_dim=16 + path that the end-to-end GEMM test can't hit (the MXFP8 GEMM kernel + itself still requires K >= 32). + """ + if not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=True) + + rowwise = quantizer.get_scale_shape((32, 16), columnwise=False) + columnwise = quantizer.get_scale_shape((32, 16), columnwise=True) + assert all(d > 0 for d in rowwise), f"rowwise scale collapsed: {rowwise}" + assert all(d > 0 for d in columnwise), f"columnwise scale collapsed: {columnwise}" + + empty = quantizer.make_empty((32, 16), dtype=torch.bfloat16, device="cuda") + assert empty._rowwise_scale_inv.numel() > 0 + assert empty._columnwise_scale_inv.numel() > 0 + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["small"]) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 668ec6f867..4134ff7fc0 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -119,14 +119,18 @@ def make_empty( shape[-1] % 16 == 0 ), f"Incorrect shape {shape} for MXFP8. Last dimension must be divisible by 16." - # Allocate FP8 data + # Allocate FP8 data. Use ceil-division for the scaling dim so partial + # trailing blocks (allowed by the relaxed last_dim % 16 constraint) get + # one scale each, matching the C++ quantizer (DIVUP in quantizer.cpp). data = None scale_inv = None if self.rowwise_usage: data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) scale_inv = torch.empty( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple( + math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), 4 + ), dtype=torch.uint8, device=device, pin_memory=pin_memory, @@ -140,7 +144,9 @@ def make_empty( shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) columnwise_scale_inv = torch.empty( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple( + math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE), 4 + ), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, device=device, @@ -190,18 +196,23 @@ def get_scale_shape( Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout """ + # Use ceil-division (mirroring C++ DIVUP in quantizer.cpp) so partial + # trailing blocks from the relaxed last_dim % 16 constraint get one + # scale each instead of collapsing to zero. if columnwise: - # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] + # Columnwise: scale_inv shape is [ceildiv(prod(shape[:-1]), BLOCK_SIZE), shape[-1]] # with padding to multiples of [4, 128] return ( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple( + math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE), 4 + ), round_up_to_nearest_multiple(shape[-1], 128), ) - # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE] + # Rowwise: scale_inv shape is [prod(shape[:-1]), ceildiv(shape[-1], BLOCK_SIZE)] # with padding to multiples of [128, 4] return ( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), 4), ) def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]: @@ -559,14 +570,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): shape = args[1] first_dim = math.prod(shape[:-1]) last_dim = shape[-1] - if ( - first_dim % MXFP8_BLOCK_SCALING_SIZE != 0 - or last_dim % MXFP8_BLOCK_SCALING_SIZE != 0 - ): + # Fall back to high-precision for shapes the quantizer cannot + # represent. Only last_dim % 16 is required (matches the relaxed + # C++ quantizer); partial trailing blocks use ceildiv. + if last_dim % 16 != 0: return super().__torch_dispatch__(func, types, args, kwargs) - rowwise_scale_inv_shape = [first_dim, last_dim // MXFP8_BLOCK_SCALING_SIZE] + rowwise_scale_inv_shape = [ + first_dim, + math.ceil(last_dim / MXFP8_BLOCK_SCALING_SIZE), + ] columnwise_scale_inv_shape = [ - first_dim // MXFP8_BLOCK_SCALING_SIZE, + math.ceil(first_dim / MXFP8_BLOCK_SCALING_SIZE), last_dim, ] if tensor._rowwise_data is not None: From 0cb7b5a981453860c5e9a1dcb69276722d4dcc42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 22:06:22 +0000 Subject: [PATCH 14/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 4134ff7fc0..ffdfc46894 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -128,9 +128,7 @@ def make_empty( data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) scale_inv = torch.empty( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple( - math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), 4 - ), + round_up_to_nearest_multiple(math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), 4), dtype=torch.uint8, device=device, pin_memory=pin_memory, From 22fe80a981069a114849c4ee9feb53aabc038c34 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 16 Apr 2026 15:33:01 -0700 Subject: [PATCH 15/18] Fix lint: undefined flat_first_dim + unused-argument - nvfp4_tensor.make_empty: flat_first_dim was never defined; use math.prod(shape[:-1]) (which is what the reader would expect from the surrounding code). - quantized_tensor.supports_quantized_allgather: move the unused-argument pragma to disable-next so pylint actually picks it up. Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/quantized_tensor.py | 5 ++--- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index fc50959385..59689d73c9 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -363,9 +363,8 @@ def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" return False - def supports_quantized_allgather( - self, inp: torch.Tensor - ) -> bool: # pylint: disable=unused-argument + # pylint: disable-next=unused-argument + def supports_quantized_allgather(self, inp: torch.Tensor) -> bool: """Whether tensor shape supports quantized all-gather. When False, the distributed all-gather falls back to gathering diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 1db7fcceb2..557420bbd9 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -337,7 +337,7 @@ def make_empty( if self.columnwise_usage: # enforce 2D shape to avoid [S, B, H] shape and B and be 1 # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - shape_2d = tuple([flat_first_dim, shape[-1]]) + shape_2d = (math.prod(shape[:-1]), shape[-1]) columnwise_data = torch.empty( self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)), dtype=torch.uint8, From ddc0f2d77ac8d550a0b20cbed8829855cbfa86be Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 16 Apr 2026 16:05:57 -0700 Subject: [PATCH 16/18] Guarded FSDP2 against the MXFP8 tensors with the first dim not divisible by 32. Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index ffdfc46894..bdf2e561b8 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -662,15 +662,22 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m shape = self.shape if self._with_gemm_swizzled_scales: raise NotImplementedError( - "FSDP2 is only supported for MXFP8Tensors with compact scales" + "FSDP2 is only supported for MXFP8Tensors with compact scales." ) + flattened_in_shape0 = math.prod(shape[:-1]) if rowwise_scale_inv is not None: # Remove padding from rowwise scale_inv - flattened_in_shape0 = math.prod(shape[:-1]) if rowwise_scale_inv.size(0) != flattened_in_shape0: rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0] if columnwise_scale_inv is not None: # Remove padding from columnwise scale_inv + if flattened_in_shape0 % MXFP8_BLOCK_SCALING_SIZE != 0: + raise NotImplementedError( + "FSDP2 during the backward pass is only supported for MXFP8Tensors with " + f"the flattened first dimension divisible by {MXFP8_BLOCK_SCALING_SIZE} " + f"and got tensor with shape {shape} with flattened first dimension " + f"{flattened_in_shape0}." + ) flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE if columnwise_scale_inv.size(0) != flattened_in_shape0: columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0] From 34567ae18907f535a510323cc22584a128a8aa7c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 23:07:47 +0000 Subject: [PATCH 17/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index bdf2e561b8..2d6bf838e9 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -673,10 +673,10 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m # Remove padding from columnwise scale_inv if flattened_in_shape0 % MXFP8_BLOCK_SCALING_SIZE != 0: raise NotImplementedError( - "FSDP2 during the backward pass is only supported for MXFP8Tensors with " - f"the flattened first dimension divisible by {MXFP8_BLOCK_SCALING_SIZE} " - f"and got tensor with shape {shape} with flattened first dimension " - f"{flattened_in_shape0}." + "FSDP2 during the backward pass is only supported for MXFP8Tensors with " + f"the flattened first dimension divisible by {MXFP8_BLOCK_SCALING_SIZE} " + f"and got tensor with shape {shape} with flattened first dimension " + f"{flattened_in_shape0}." ) flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE if columnwise_scale_inv.size(0) != flattened_in_shape0: From aa4863ac48391f855d4591d245a482234b057cd1 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 20 Apr 2026 14:50:19 -0700 Subject: [PATCH 18/18] Scope quantizer.cpp dim checks to quantization requirements Remove leading-dim alignment checks that duplicate the cuBLAS FP8 alignment requirement already enforced (with clearer error messages) in cublaslt_gemm.cu via min_alignment_elements(): MXFP8 last_dim%16 and NVFP4 last_dim%32 in both create_tensor and get_scale_shape. Keep quantization-specific checks: NVFP4 last_dim%2 for the 4-bit byte-packed storage, and first_dim%16 for NVFP4 with RHT (Hadamard transform). Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 4f88c19734..57cb416b3c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1346,10 +1346,6 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - NVTE_CHECK(flat_last_dim % 16 == 0, - "MXFP8 requires the last tensor dimension to be divisible by 16," - " got tensor with shape ", - shape, " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1666,11 +1662,6 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); - NVTE_CHECK(last_dim % 16 == 0, - "MXFP8 requires the last tensor dimension to be divisible by 16," - " got tensor with shape ", - shape); - std::vector scale_shape; bool rowwise_usage = !columnwise; @@ -1741,8 +1732,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - NVTE_CHECK(flat_last_dim % 32 == 0, - "NVFP4 requires the last tensor dimension to be divisible by 32," + NVTE_CHECK(flat_last_dim % 2 == 0, + "NVFP4 requires the last tensor dimension to be divisible by 2," " got tensor with shape ", shape, " (flat_first_dim=", flat_first_dim, ", flat_last_dim=", flat_last_dim, ")"); if (this->with_rht) { @@ -2446,11 +2437,6 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); auto flat_first_dim = numel / last_dim; - NVTE_CHECK(last_dim % 32 == 0, - "NVFP4 requires the last tensor dimension to be divisible by 32," - " got tensor with shape ", - shape); - std::vector scale_shape; bool rowwise_usage = !columnwise;