From ffe94e4ee505c9191c5fbd31f3c2c2c5a3161786 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 23 Apr 2026 14:53:24 -0700 Subject: [PATCH 1/5] Fix contiguous path for k=2880 Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/swizzle/swizzle.cu | 140 ++++++++++++++---- .../pytorch/csrc/extensions/swizzle.cpp | 38 ++++- 2 files changed, 143 insertions(+), 35 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index de4fdbb040..2a967baa34 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -132,19 +132,35 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, extern __shared__ int slm[]; // load, global -> regs + // Each register read for a given i is along the M direction at K-coord + // (bid_x * TB_DIM * SF_TILE_DIM_K + threadIdx.y * SF_TILE_DIM_K + i). When that + // K-coord is past original_K, the entire register is out of the per-tensor data + // region (which may be the unpadded compact extent), so we must NOT issue the + // __ldg there -- it could read past the per-tensor buffer (and, for the last + // tensor in a grouped allocation, past the end of the allocation entirely). LType regs_vec[N_SF_PER_TD_PER_TILE]; if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && threadIdx.y < k_tiles_in_tb) { + const int k_base = bid_x * TB_DIM * SF_TILE_DIM_K + threadIdx.y * SF_TILE_DIM_K; #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { const int thread_offset = (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; + const int k_coord = k_base + i; + if (padding_k && k_coord >= original_K) { + // Entire register is past original_K: zero directly without loading. + uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); +#pragma unroll + for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; + continue; + } regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); - // Pad zeros - if (padding_m || padding_k) { + // Per-byte M masking is still needed when only part of the register is past + // original_M (i.e. K-coord is in range but the M position spans the boundary). + if (padding_m) { for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (input_offset + thread_offset) * sizeof(int) + j; - if (index / M >= original_K || index % M >= original_M) { + if (index % M >= original_M) { reinterpret_cast(regs_vec + i)[j] = 0; } } @@ -254,17 +270,32 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, extern __shared__ int4 slm_v4i[]; // load, global -> regs + // Each register read for a given i is along the K direction at row + // (bid_y * SF_TILE_DIM_M + i * TB_DIM + threadIdx.y). When that row is past + // original_M, the entire register is out of the per-tensor data region (which + // may be the unpadded compact extent), so we must NOT issue the __ldg there -- + // it could read past the per-tensor buffer (and, for the last tensor in a + // grouped allocation, past the end of the allocation entirely). LType regs_vec[N_SF_PER_TD_PER_TILE]; if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int row = bid_y * SF_TILE_DIM_M + i * TB_DIM + threadIdx.y; const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + if (padding_m && row >= original_M) { + // Entire register is past original_M: zero directly without loading. + uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); +#pragma unroll + for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; + continue; + } regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); - if (padding_m || padding_k) { - // Pad zeros + // Per-byte K masking is still needed when only part of the register is past + // original_K (i.e. row is in range but the K position spans the boundary). + if (padding_k) { for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (input_offset + thread_offset) * sizeof(int) + j; - if (index / K >= original_M || index % K >= original_K) { + if (index % K >= original_K) { reinterpret_cast(regs_vec + i)[j] = 0; } } @@ -628,11 +659,16 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) grouped_swizzle_row_scaling_uniform_shape_kernel(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, - const size_t scale_stride_bytes) { + const size_t input_stride_bytes, + const size_t output_stride_bytes) { const int tensor_id = blockIdx.z; + // Input and output strides may differ: input is in the kernel-produced "compact" + // layout (per-tensor stride = original_M * padded_k * elem_size) when callers + // pass the unswizzled grouped scale buffer as-is, while the output is always in + // the per-tensor padded ("swizzle-ready") layout (padded_m * padded_k * elem_size). const uint8_t* input_base = - reinterpret_cast(input) + tensor_id * scale_stride_bytes; - uint8_t* output_base = reinterpret_cast(output) + tensor_id * scale_stride_bytes; + reinterpret_cast(input) + tensor_id * input_stride_bytes; + uint8_t* output_base = reinterpret_cast(output) + tensor_id * output_stride_bytes; swizzle_row_scaling_kernel_impl( input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); @@ -643,11 +679,15 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) grouped_swizzle_col_scaling_uniform_shape_kernel(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, - const size_t scale_stride_bytes) { + const size_t input_stride_bytes, + const size_t output_stride_bytes) { const int tensor_id = blockIdx.z; + // See the rowwise kernel for stride semantics. For columnwise the per-tensor + // compact stride is DIVUP(original_K, 1) * padded_m * elem_size (i.e. the + // unpadded scale-row count in the K direction times the padded M extent). const uint8_t* input_base = - reinterpret_cast(input) + tensor_id * scale_stride_bytes; - uint8_t* output_base = reinterpret_cast(output) + tensor_id * scale_stride_bytes; + reinterpret_cast(input) + tensor_id * input_stride_bytes; + uint8_t* output_base = reinterpret_cast(output) + tensor_id * output_stride_bytes; swizzle_col_scaling_kernel_impl( input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); @@ -1924,23 +1964,59 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* const size_t padded_m = round_up_to_multiple(m, 128); const size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); - const size_t scale_elems = padded_m * padded_k; + // Per-tensor scale-element counts: + // - "padded" layout: each tensor occupies padded_m * padded_k elements + // (total buffer = num_tensors * padded_m * padded_k). + // - "compact" layout (what the grouped MXFP8 quantize kernel actually writes): + // per-tensor stride is m * padded_k (rowwise) or DIVUP(k,32) * padded_m + // (columnwise) and the total buffer the C++ allocator hands out has its + // grouped first dim padded up to a multiple of 128 (rowwise) or 4 + // (columnwise) -- so the buffer may be slightly larger than + // num_tensors * compact_scale_elems, with trailing alignment slack at + // the very end (never read because of the per-tensor row/k guard in the + // kernel impl). + // The output is always written in the padded layout. The input may be in + // either layout; the kernel handles the compact case safely by using + // different per-tensor strides for input vs output and skipping loads past + // the per-tensor extent. + const size_t padded_scale_elems = padded_m * padded_k; + const size_t compact_scale_elems = + rowwise ? m * padded_k + : DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)) * padded_m; + const size_t compact_total_scale_elems = + rowwise + ? round_up_to_multiple(input->num_tensors * m, 128) * padded_k + : round_up_to_multiple(input->num_tensors * DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), + 4) * + padded_m; const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) : typeToSize(input->columnwise_scale_inv.dtype); - const size_t scale_stride_bytes = scale_elems * scale_elem_size; - if (rowwise) { - NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems, - "Grouped input scale_inv size does not match expected packed size."); - NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems, - "Grouped output scale_inv size does not match expected packed size."); + const size_t input_scale_numel = + rowwise ? input->scale_inv.numel() : input->columnwise_scale_inv.numel(); + const size_t output_scale_numel = + rowwise ? output->scale_inv.numel() : output->columnwise_scale_inv.numel(); + + bool input_is_compact; + if (input_scale_numel == input->num_tensors * padded_scale_elems) { + input_is_compact = false; + } else if (input_scale_numel == compact_total_scale_elems) { + input_is_compact = true; } else { - NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems, - "Grouped input columnwise_scale_inv size does not match expected packed size."); - NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems, - "Grouped output columnwise_scale_inv size does not match expected packed size."); + NVTE_ERROR( + "Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"), + " size does not match expected packed size (got ", input_scale_numel, + ", expected either ", input->num_tensors * padded_scale_elems, " (per-tensor padded) or ", + compact_total_scale_elems, " (compact))."); } + NVTE_CHECK(output_scale_numel == input->num_tensors * padded_scale_elems, + "Grouped output ", (rowwise ? "scale_inv" : "columnwise_scale_inv"), + " size does not match expected per-tensor padded size."); + + const size_t input_stride_bytes = + (input_is_compact ? compact_scale_elems : padded_scale_elems) * scale_elem_size; + const size_t output_stride_bytes = padded_scale_elems * scale_elem_size; const int num_tiles_m = padded_m / SF_TILE_DIM_M; const int num_tiles_k = padded_k / SF_TILE_DIM_K; @@ -1971,7 +2047,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_row_scaling_uniform_shape_kernel <<>>(input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - scale_stride_bytes); + input_stride_bytes, + output_stride_bytes); break; case 2: NVTE_CHECK_CUDA(cudaFuncSetAttribute( @@ -1980,7 +2057,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_row_scaling_uniform_shape_kernel <<>>(input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - scale_stride_bytes); + input_stride_bytes, + output_stride_bytes); break; case 1: NVTE_CHECK_CUDA(cudaFuncSetAttribute( @@ -1989,7 +2067,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_row_scaling_uniform_shape_kernel <<>>(input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - scale_stride_bytes); + input_stride_bytes, + output_stride_bytes); break; default: NVTE_ERROR("Not valid vec_load_size."); @@ -2003,7 +2082,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_col_scaling_uniform_shape_kernel <<>>(input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - scale_stride_bytes); + input_stride_bytes, + output_stride_bytes); break; case 2: NVTE_CHECK_CUDA(cudaFuncSetAttribute( @@ -2012,7 +2092,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_col_scaling_uniform_shape_kernel <<>>(input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - scale_stride_bytes); + input_stride_bytes, + output_stride_bytes); break; case 1: NVTE_CHECK_CUDA(cudaFuncSetAttribute( @@ -2021,7 +2102,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_col_scaling_uniform_shape_kernel <<>>(input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - scale_stride_bytes); + input_stride_bytes, + output_stride_bytes); break; default: NVTE_ERROR("Not valid vec_load_size."); diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index cbaabaad17..d8ab830c48 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -403,16 +403,39 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW tensor_offsets.data_ptr, static_cast(tensor_offsets.dtype), tensor_offsets.shape); } + // Per-tensor logical dimensions (uniform-shape grouped tensor). + const size_t num_tensors = input.num_tensors(); + const auto logical_shape_nvte = input.logical_shape(); + NVTE_CHECK(logical_shape_nvte.ndim >= 2, + "Grouped GEMM swizzle expects logical_shape with ndim >= 2."); + const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors; + const size_t per_tensor_last_dim = logical_shape_nvte.data[logical_shape_nvte.ndim - 1]; + constexpr size_t kMxfp8BlockSize = 32; + + // Output is always allocated in the per-tensor padded ("swizzle-ready") layout + // so the cuDNN grouped GEMM consumer sees the correct stride between experts. + // The swizzle kernel itself handles converting from the kernel-emitted compact + // layout (per-tensor first dim is the unpadded value) to this padded layout. + auto compute_padded_grouped_scale_shape = [&](bool rowwise) { + const size_t m = rowwise ? per_tensor_first_dim : per_tensor_last_dim; + const size_t k = rowwise ? per_tensor_last_dim : per_tensor_first_dim; + const size_t padded_m = ceildiv(m, size_t{128}) * 128; + const size_t padded_k = ceildiv(ceildiv(k, kMxfp8BlockSize), size_t{4}) * 4; + return std::vector{num_tensors * padded_m, padded_k}; + }; + if (swizzle_rowwise) { const auto data = input.get_rowwise_data(); const auto data_dtype = static_cast(data.dtype); const auto scales_dtype = static_cast(row_scales.dtype); swizzle_input.set_rowwise_data(nullptr, data_dtype, data.shape); swizzle_input.set_rowwise_scale_inv(row_scales.data_ptr, scales_dtype, row_scales.shape); - rowwise_scales_pyt = allocateSpace(row_scales.shape, scales_dtype, false); + const auto padded_shape = compute_padded_grouped_scale_shape(/*rowwise=*/true); + rowwise_scales_pyt = allocateSpace(padded_shape, scales_dtype, false); + NVTEShape padded_shape_nvte = nvte_make_shape(padded_shape.data(), padded_shape.size()); swizzle_output.set_rowwise_data(nullptr, data_dtype, data.shape); swizzle_output.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, - row_scales.shape); + padded_shape_nvte); } if (swizzle_columnwise) { const auto data = input.get_columnwise_data(); @@ -420,10 +443,12 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW const auto scales_dtype = static_cast(col_scales.dtype); swizzle_input.set_columnwise_data(nullptr, data_dtype, data.shape); swizzle_input.set_columnwise_scale_inv(col_scales.data_ptr, scales_dtype, col_scales.shape); - columnwise_scales_pyt = allocateSpace(col_scales.shape, scales_dtype, false); + const auto padded_shape = compute_padded_grouped_scale_shape(/*rowwise=*/false); + columnwise_scales_pyt = allocateSpace(padded_shape, scales_dtype, false); + NVTEShape padded_shape_nvte = nvte_make_shape(padded_shape.data(), padded_shape.size()); swizzle_output.set_columnwise_data(nullptr, data_dtype, data.shape); swizzle_output.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype, - col_scales.shape); + padded_shape_nvte); } swizzle_output.set_with_gemm_swizzled_scales(true); @@ -434,12 +459,13 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW if (swizzle_rowwise) { const auto scales_dtype = static_cast(row_scales.dtype); - input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, row_scales.shape); + input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, + getTensorShape(*rowwise_scales_pyt)); } if (swizzle_columnwise) { const auto scales_dtype = static_cast(col_scales.dtype); input.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype, - col_scales.shape); + getTensorShape(*columnwise_scales_pyt)); } input.set_with_gemm_swizzled_scales(true); return SwizzledGroupedScales{std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)}; From 5de0389f8c6b662a647cd1c0dc37178dfb57633c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 23 Apr 2026 23:56:38 +0000 Subject: [PATCH 2/5] format Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/swizzle/swizzle.cu | 67 +++++++++----------- 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 2a967baa34..f879ecea63 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -1981,14 +1981,12 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* // the per-tensor extent. const size_t padded_scale_elems = padded_m * padded_k; const size_t compact_scale_elems = - rowwise ? m * padded_k - : DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)) * padded_m; + rowwise ? m * padded_k : DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)) * padded_m; const size_t compact_total_scale_elems = - rowwise - ? round_up_to_multiple(input->num_tensors * m, 128) * padded_k - : round_up_to_multiple(input->num_tensors * DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), - 4) * - padded_m; + rowwise ? round_up_to_multiple(input->num_tensors * m, 128) * padded_k + : round_up_to_multiple( + input->num_tensors * DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4) * + padded_m; const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) : typeToSize(input->columnwise_scale_inv.dtype); @@ -2004,14 +2002,13 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* } else if (input_scale_numel == compact_total_scale_elems) { input_is_compact = true; } else { - NVTE_ERROR( - "Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"), - " size does not match expected packed size (got ", input_scale_numel, - ", expected either ", input->num_tensors * padded_scale_elems, " (per-tensor padded) or ", - compact_total_scale_elems, " (compact))."); + NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"), + " size does not match expected packed size (got ", input_scale_numel, + ", expected either ", input->num_tensors * padded_scale_elems, + " (per-tensor padded) or ", compact_total_scale_elems, " (compact))."); } - NVTE_CHECK(output_scale_numel == input->num_tensors * padded_scale_elems, - "Grouped output ", (rowwise ? "scale_inv" : "columnwise_scale_inv"), + NVTE_CHECK(output_scale_numel == input->num_tensors * padded_scale_elems, "Grouped output ", + (rowwise ? "scale_inv" : "columnwise_scale_inv"), " size does not match expected per-tensor padded size."); const size_t input_stride_bytes = @@ -2045,30 +2042,27 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_row_scaling_uniform_shape_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - input_stride_bytes, - output_stride_bytes); + <<>>( + input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); break; case 2: NVTE_CHECK_CUDA(cudaFuncSetAttribute( grouped_swizzle_row_scaling_uniform_shape_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - input_stride_bytes, - output_stride_bytes); + <<>>( + input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); break; case 1: NVTE_CHECK_CUDA(cudaFuncSetAttribute( grouped_swizzle_row_scaling_uniform_shape_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - input_stride_bytes, - output_stride_bytes); + <<>>( + input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); break; default: NVTE_ERROR("Not valid vec_load_size."); @@ -2080,30 +2074,27 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_col_scaling_uniform_shape_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - input_stride_bytes, - output_stride_bytes); + <<>>( + input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); break; case 2: NVTE_CHECK_CUDA(cudaFuncSetAttribute( grouped_swizzle_col_scaling_uniform_shape_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - input_stride_bytes, - output_stride_bytes); + <<>>( + input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); break; case 1: NVTE_CHECK_CUDA(cudaFuncSetAttribute( grouped_swizzle_col_scaling_uniform_shape_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - input_stride_bytes, - output_stride_bytes); + <<>>( + input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); break; default: NVTE_ERROR("Not valid vec_load_size."); From bccbf6ab3a71884b34a1b81063341f3268bb530c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 24 Apr 2026 19:46:06 +0000 Subject: [PATCH 3/5] Review suggestion from @Oleg-Goncharov Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/swizzle/swizzle.cu | 142 ++++++++++++++----- 1 file changed, 110 insertions(+), 32 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index f879ecea63..85fbe73f00 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -91,7 +91,11 @@ __device__ inline void regs_unshuffle_with_bit_shifts(LType* regs_vec) { for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; } -template +// IS_PADDED_K / IS_PADDED_M select the boundary-block specialization at compile +// time so the inner load loop avoids the per-iteration runtime checks. The +// caller computes the runtime predicates from blockIdx/gridDim once per block +// (uniform across the block) and dispatches to the right specialization. +template __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, const int bid_x, @@ -117,9 +121,6 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; } - bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); - bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); - const int input_offset = bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; const int32_t* input_i32 = reinterpret_cast(input) + input_offset; @@ -147,17 +148,19 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int thread_offset = (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; const int k_coord = k_base + i; - if (padding_k && k_coord >= original_K) { - // Entire register is past original_K: zero directly without loading. - uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); + if constexpr (IS_PADDED_K) { + if (k_coord >= original_K) { + // Entire register is past original_K: zero directly without loading. + uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); #pragma unroll - for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; - continue; + for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; + continue; + } } regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); // Per-byte M masking is still needed when only part of the register is past // original_M (i.e. K-coord is in range but the M position spans the boundary). - if (padding_m) { + if constexpr (IS_PADDED_M) { for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (input_offset + thread_offset) * sizeof(int) + j; if (index % M >= original_M) { @@ -199,12 +202,43 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, } } +// Dispatch helper: pick the right (IS_PADDED_K, IS_PADDED_M) col-scaling impl +// specialization at runtime based on the per-block padding predicates. The +// branching here is uniform across all threads in the block, so the indirect +// path each block takes still inlines cleanly. +template +__device__ __forceinline__ void dispatch_swizzle_col_scaling_kernel_impl( + const void* input, void* output, const int M, const int K, const int original_M, + const int original_K, const int bid_x, const int bid_y, const int grid_dim_x, + const int grid_dim_y, const bool padding_k, const bool padding_m) { + if (padding_k && padding_m) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else if (padding_k) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else if (padding_m) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } +} + template __global__ void __launch_bounds__(TB_DIM* TB_DIM) swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K, const int original_M, const int original_K) { - swizzle_col_scaling_kernel_impl( - input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y, + padding_k, padding_m); } template @@ -240,7 +274,11 @@ __device__ inline void regs_unshuffle(LType* regs_vec) { for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; } -template +// IS_PADDED_K / IS_PADDED_M select the boundary-block specialization at compile +// time so the inner load loop avoids the per-iteration runtime checks. The +// caller computes the runtime predicates from blockIdx/gridDim once per block +// (uniform across the block) and dispatches to the right specialization. +template __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, const int bid_x, @@ -259,9 +297,6 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; } - bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); - bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); - const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; const int* input_i32 = reinterpret_cast(input) + input_offset; int* output_i32 = reinterpret_cast(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 + @@ -282,17 +317,19 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { const int row = bid_y * SF_TILE_DIM_M + i * TB_DIM + threadIdx.y; const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; - if (padding_m && row >= original_M) { - // Entire register is past original_M: zero directly without loading. - uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); + if constexpr (IS_PADDED_M) { + if (row >= original_M) { + // Entire register is past original_M: zero directly without loading. + uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); #pragma unroll - for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; - continue; + for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; + continue; + } } regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); // Per-byte K masking is still needed when only part of the register is past // original_K (i.e. row is in range but the K position spans the boundary). - if (padding_k) { + if constexpr (IS_PADDED_K) { for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (input_offset + thread_offset) * sizeof(int) + j; if (index % K >= original_K) { @@ -324,12 +361,43 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, } } +// Dispatch helper: pick the right (IS_PADDED_K, IS_PADDED_M) row-scaling impl +// specialization at runtime based on the per-block padding predicates. The +// branching here is uniform across all threads in the block, so the indirect +// path each block takes still inlines cleanly. +template +__device__ __forceinline__ void dispatch_swizzle_row_scaling_kernel_impl( + const void* input, void* output, const int M, const int K, const int original_M, + const int original_K, const int bid_x, const int bid_y, const int grid_dim_x, + const int grid_dim_y, const bool padding_k, const bool padding_m) { + if (padding_k && padding_m) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else if (padding_k) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else if (padding_m) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } +} + template __global__ void __launch_bounds__(TB_DIM* TB_DIM) swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, const int original_M, const int original_K) { - swizzle_row_scaling_kernel_impl( - input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y, + padding_k, padding_m); } // Narrow-K specialization for row scaling swizzle. @@ -669,9 +737,11 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) const uint8_t* input_base = reinterpret_cast(input) + tensor_id * input_stride_bytes; uint8_t* output_base = reinterpret_cast(output) + tensor_id * output_stride_bytes; - swizzle_row_scaling_kernel_impl( + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_row_scaling_kernel_impl( input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, - gridDim.y); + gridDim.y, padding_k, padding_m); } template @@ -688,9 +758,11 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) const uint8_t* input_base = reinterpret_cast(input) + tensor_id * input_stride_bytes; uint8_t* output_base = reinterpret_cast(output) + tensor_id * output_stride_bytes; - swizzle_col_scaling_kernel_impl( + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_col_scaling_kernel_impl( input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, - gridDim.y); + gridDim.y, padding_k, padding_m); } template @@ -791,8 +863,11 @@ __global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_ const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; - swizzle_row_scaling_kernel_impl( - input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + const bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + const bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + dispatch_swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y, padding_k, + padding_m); } template @@ -821,8 +896,11 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; - swizzle_col_scaling_kernel_impl( - input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + const bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + const bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + dispatch_swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y, padding_k, + padding_m); } template From 791704260f7c3d88b317039cd4f34f6f998fd8cc Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 27 Apr 2026 21:28:49 +0000 Subject: [PATCH 4/5] Add test for swizzle + padding fusion Signed-off-by: Kirthi Shankar Sivamani --- tests/cpp/operator/test_swizzle.cu | 248 +++++++++++++++++++++++++++++ 1 file changed, 248 insertions(+) diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 1ea82f19cd..6ea951c73c 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -541,6 +541,254 @@ INSTANTIATE_TEST_SUITE_P( } ); +// Build a "compact" grouped MXFP8 scale_inv buffer for swizzle input. This is +// the layout produced by the grouped MXFP8 quantize kernel: the per-tensor +// stride is `M_per_tensor * padded_K` (rowwise) or `DIVUP(M,32) * padded_K_for_cols` +// (columnwise) -- i.e. NO per-tensor padding rows are inserted. The total buffer +// is rounded up at its very end to a multiple of 128 (rowwise) or 4 (columnwise) +// in the grouped first dim, matching what the C++ allocator hands out. +// +// Each tensor's compact scales are gathered from the unpadded-prefix rows of +// that tensor's per-tensor padded CPU scale buffer. +namespace { + +struct CompactScaleBuffer { + test::CudaPtr<> ptr; + size_t numel{0}; +}; + +CompactScaleBuffer gather_compact_grouped_scale( + const std::vector>& tensors, + size_t M_per_tensor, size_t K_per_tensor, bool rowwise) { + using namespace test; + constexpr size_t BLOCK = 32; + const size_t num_tensors = tensors.size(); + + size_t per_tensor_first_unpadded; + size_t per_tensor_last_padded; + size_t group_first_align; + if (rowwise) { + per_tensor_first_unpadded = M_per_tensor; + const size_t scale_K = (K_per_tensor + BLOCK - 1) / BLOCK; + per_tensor_last_padded = ((scale_K + 4 - 1) / 4) * 4; + group_first_align = 128; + } else { + per_tensor_first_unpadded = (M_per_tensor + BLOCK - 1) / BLOCK; + per_tensor_last_padded = ((K_per_tensor + 128 - 1) / 128) * 128; + group_first_align = 4; + } + + const size_t per_tensor_compact_numel = + per_tensor_first_unpadded * per_tensor_last_padded; + const size_t total_first = + ((num_tensors * per_tensor_first_unpadded + group_first_align - 1) + / group_first_align) * group_first_align; + const size_t total_numel = total_first * per_tensor_last_padded; + + std::vector host_buf(total_numel, 0); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + const NVTEShape padded_shape = rowwise ? tensors[i]->rowwise_scale_inv_shape() + : tensors[i]->columnwise_scale_inv_shape(); + NVTE_CHECK(padded_shape.data[1] == per_tensor_last_padded, + "Unexpected per-tensor padded last dim in compact gather."); + const uint8_t* src = rowwise + ? tensors[i]->rowwise_cpu_scale_inv_ptr() + : tensors[i]->columnwise_cpu_scale_inv_ptr(); + uint8_t* dst = host_buf.data() + i * per_tensor_compact_numel; + // Per-tensor padded buffer is row-major (padded_first, padded_last); copy + // only the first `per_tensor_first_unpadded` rows. + std::memcpy(dst, src, per_tensor_compact_numel); + } + + CompactScaleBuffer out; + out.ptr = cuda_alloc(total_numel); + NVTE_CHECK_CUDA(cudaMemcpy(out.ptr.get(), host_buf.data(), + total_numel, cudaMemcpyHostToDevice)); + out.numel = total_numel; + return out; +} + +} // namespace + +// Tests that grouped_swizzle_for_gemm correctly handles a COMPACT input +// scale_inv buffer (no per-tensor padding rows), producing an output in the +// per-tensor padded layout with padded regions zeroed out. This is the layout +// produced by the grouped MXFP8 quantize kernel; previously the swizzle kernel +// asserted the input matched the per-tensor padded packed size, which broke +// grouped MLP weights with M not a multiple of 128. +void performTestGroupedSwizzleMXFP8CompactInput(const int num_tensors, const size_t M, + const size_t K) { + using namespace transformer_engine; + using namespace test; + + std::vector> input_tensors; + std::vector> output_tensors; + std::vector input_ptrs, output_ptrs; + input_tensors.reserve(num_tensors); + output_tensors.reserve(num_tensors); + input_ptrs.reserve(num_tensors); + output_ptrs.reserve(num_tensors); + + constexpr size_t BLOCK_SIZE = 32; + const std::vector shape{M, K}; + for (int i = 0; i < num_tensors; ++i) { + auto input = std::make_unique("input_" + std::to_string(i), shape, + DType::kFloat8E4M3, true, true, + NVTE_MXFP8_1D_SCALING); + auto output = std::make_unique("output_" + std::to_string(i), shape, + DType::kFloat8E4M3, true, true, + NVTE_MXFP8_1D_SCALING); + fillUniform(input.get()); + fillUniform(output.get()); + + // Zero the per-tensor padded regions so the reference (which sees the + // padded layout) and the kernel (which sees the compact layout but writes + // zeros into output padding) agree byte-for-byte. + input->to_cpu(); + const NVTEShape rs = input->rowwise_scale_inv_shape(); + zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(), + rs.data[0], rs.data[1], + M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + const NVTEShape cs = input->columnwise_scale_inv_shape(); + zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(), + cs.data[0], cs.data[1], + (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + input->from_cpu(); + + input_ptrs.push_back(input.get()); + output_ptrs.push_back(output.get()); + input_tensors.emplace_back(std::move(input)); + output_tensors.emplace_back(std::move(output)); + } + + // Build a per-tensor padded grouped output via the standard helper, and a + // compact-scale grouped input by overriding the scale_inv buffers of a + // padded grouped input with newly allocated compact buffers. + GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING); + GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING); + + CompactScaleBuffer compact_row = + gather_compact_grouped_scale(input_tensors, M, K, /*rowwise=*/true); + CompactScaleBuffer compact_col = + gather_compact_grouped_scale(input_tensors, M, K, /*rowwise=*/false); + + grouped_input.scale_inv = std::move(compact_row.ptr); + grouped_input.columnwise_scale_inv = std::move(compact_col.ptr); + { + NVTEShape s = nvte_make_shape(&compact_row.numel, 1); + NVTEBasicTensor t{grouped_input.scale_inv.get(), kNVTEFloat8E8M0, s}; + nvte_set_grouped_tensor_param(grouped_input.get_handle(), + kNVTEGroupedRowwiseScaleInv, &t, sizeof(t)); + } + { + NVTEShape s = nvte_make_shape(&compact_col.numel, 1); + NVTEBasicTensor t{grouped_input.columnwise_scale_inv.get(), kNVTEFloat8E8M0, s}; + nvte_set_grouped_tensor_param(grouped_input.get_handle(), + kNVTEGroupedColumnwiseScaleInv, &t, sizeof(t)); + } + + const uint8_t input_swizzled = 0; + nvte_set_grouped_tensor_param(grouped_input.get_handle(), + kNVTEGroupedWithGEMMSwizzledScales, + &input_swizzled, sizeof(input_swizzled)); + const uint8_t output_swizzled = 1; + nvte_set_grouped_tensor_param(grouped_output.get_handle(), + kNVTEGroupedWithGEMMSwizzledScales, + &output_swizzled, sizeof(output_swizzled)); + + const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape(); + const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape(); + const size_t row_numel = row_shape.data[0] * row_shape.data[1]; + const size_t col_numel = col_shape.data[0] * col_shape.data[1]; + + // Memset to a non-zero sentinel so we can detect kernel failures to write + // padded regions (those must be overwritten with zero by the kernel). + NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0xCD, + num_tensors * row_numel)); + NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0xCD, + num_tensors * col_numel)); + + nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(), + grouped_output.get_handle(), 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + std::vector output_row(num_tensors * row_numel); + std::vector output_col(num_tensors * col_numel); + NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), grouped_output.scale_inv.get(), + output_row.size(), cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(), + grouped_output.columnwise_scale_inv.get(), + output_col.size(), cudaMemcpyDeviceToHost)); + + std::vector ref_row(num_tensors * row_numel); + std::vector ref_col(num_tensors * col_numel); + for (int i = 0; i < num_tensors; ++i) { + compute_ref_swizzle<128, 4, true>( + input_tensors[i]->rowwise_cpu_scale_inv_ptr(), + ref_row.data() + i * row_numel, + row_shape.data[0], row_shape.data[1]); + compute_ref_swizzle<128, 4, false>( + input_tensors[i]->columnwise_cpu_scale_inv_ptr(), + ref_col.data() + i * col_numel, + col_shape.data[1], col_shape.data[0]); + } + + compareResults("grouped_swizzle_compact_rowwise", output_row.data(), + ref_row.data(), num_tensors * row_numel); + compareResults("grouped_swizzle_compact_colwise", output_col.data(), + ref_col.data(), num_tensors * col_numel); +} + +class SwizzleGroupedCompactInputTestSuite + : public ::testing::TestWithParam> {}; + +TEST_P(SwizzleGroupedCompactInputTestSuite, TestGroupedSwizzleMXFP8CompactInput) { + const auto num_tensors = std::get<0>(GetParam()); + const auto M = std::get<1>(GetParam()); + const auto K = std::get<2>(GetParam()); + performTestGroupedSwizzleMXFP8CompactInput(num_tensors, M, K); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwizzleGroupedCompactInputTestSuite, + ::testing::Values( + // Aligned M and K. Per-tensor compact stride == per-tensor padded stride, + // so the kernel may use either layout; serves as a sanity check that the + // compact-input plumbing doesn't regress aligned shapes. + std::make_tuple(3, 256, 256), + std::make_tuple(4, 128, 128), + // M NOT divisible by 128 (the original-bug case): per-tensor compact stride + // shrinks vs padded. We pick (num_tensors, M) so that BOTH + // round_up(N * M, 128) != N * round_up(M, 128) (rowwise) + // round_up(N * DIVUP(M,32), 4) != N * round_up(DIVUP(M,32),4) (colwise) + // i.e. compact_total != padded_total on either axis, so the kernel + // unambiguously detects the compact layout. + std::make_tuple(4, 200, 256), + std::make_tuple(4, 65, 256), + std::make_tuple(2, 2880, 2880), // shape from the originally failing workload + // K not divisible by 128 (DIVUP(K,32) padded up to a multiple of 4). + std::make_tuple(3, 256, 160), + std::make_tuple(2, 256, 96), + // Neither M nor K aligned. + std::make_tuple(4, 200, 160), + std::make_tuple(4, 33, 64), + std::make_tuple(2, 1, 32), + // num_tensors * M not aligned to 128 -> exercises trailing alignment slack + // at the end of the compact rowwise buffer. + std::make_tuple(3, 64, 128), + std::make_tuple(5, 33, 96) + ), + [](const testing::TestParamInfo& info) { + return "n" + std::to_string(std::get<0>(info.param)) + + "_M" + std::to_string(std::get<1>(info.param)) + + "_K" + std::to_string(std::get<2>(info.param)); + } +); + class UnswizzleGroupedTestSuite : public ::testing::TestWithParam> {}; From 731f4e41ea69f3bea6f58eddf616a28be94a1aea Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 28 Apr 2026 23:13:52 +0000 Subject: [PATCH 5/5] Address review comments Signed-off-by: Kirthi Shankar Sivamani --- tests/cpp/operator/test_swizzle.cu | 25 +++--- transformer_engine/common/common.h | 20 +++++ transformer_engine/common/swizzle/swizzle.cu | 81 +++++--------------- 3 files changed, 51 insertions(+), 75 deletions(-) diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 6ea951c73c..3fec5062ff 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -248,11 +248,11 @@ void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const const NVTEShape rs = input->rowwise_scale_inv_shape(); zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(), rs.data[0], rs.data[1], - M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + M, divide_round_up(K, BLOCK_SIZE)); const NVTEShape cs = input->columnwise_scale_inv_shape(); zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(), cs.data[0], cs.data[1], - (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + divide_round_up(M, BLOCK_SIZE), K); input->from_cpu(); input_ptrs.push_back(input.get()); @@ -444,11 +444,11 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si const NVTEShape rs = orig->rowwise_scale_inv_shape(); zero_scale_inv_padding(orig->rowwise_cpu_scale_inv_ptr(), rs.data[0], rs.data[1], - M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + M, divide_round_up(K, BLOCK_SIZE)); const NVTEShape cs = orig->columnwise_scale_inv_shape(); zero_scale_inv_padding(orig->columnwise_cpu_scale_inv_ptr(), cs.data[0], cs.data[1], - (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + divide_round_up(M, BLOCK_SIZE), K); orig->from_cpu(); orig_ptrs.push_back(orig.get()); @@ -569,20 +569,19 @@ CompactScaleBuffer gather_compact_grouped_scale( size_t group_first_align; if (rowwise) { per_tensor_first_unpadded = M_per_tensor; - const size_t scale_K = (K_per_tensor + BLOCK - 1) / BLOCK; - per_tensor_last_padded = ((scale_K + 4 - 1) / 4) * 4; + per_tensor_last_padded = + round_up_to_nearest_multiple(divide_round_up(K_per_tensor, BLOCK), 4); group_first_align = 128; } else { - per_tensor_first_unpadded = (M_per_tensor + BLOCK - 1) / BLOCK; - per_tensor_last_padded = ((K_per_tensor + 128 - 1) / 128) * 128; + per_tensor_first_unpadded = divide_round_up(M_per_tensor, BLOCK); + per_tensor_last_padded = round_up_to_nearest_multiple(K_per_tensor, 128); group_first_align = 4; } const size_t per_tensor_compact_numel = per_tensor_first_unpadded * per_tensor_last_padded; - const size_t total_first = - ((num_tensors * per_tensor_first_unpadded + group_first_align - 1) - / group_first_align) * group_first_align; + const size_t total_first = round_up_to_nearest_multiple( + num_tensors * per_tensor_first_unpadded, group_first_align); const size_t total_numel = total_first * per_tensor_last_padded; std::vector host_buf(total_numel, 0); @@ -649,11 +648,11 @@ void performTestGroupedSwizzleMXFP8CompactInput(const int num_tensors, const siz const NVTEShape rs = input->rowwise_scale_inv_shape(); zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(), rs.data[0], rs.data[1], - M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + M, divide_round_up(K, BLOCK_SIZE)); const NVTEShape cs = input->columnwise_scale_inv_shape(); zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(), cs.data[0], cs.data[1], - (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + divide_round_up(M, BLOCK_SIZE), K); input->from_cpu(); input_ptrs.push_back(input.get()); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 68aa0f4c51..c1b3f8f427 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -946,6 +946,26 @@ struct TypeInfo { } \ } +#define TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(INTEGER_ELTS_NUM, type, ...) \ + switch (INTEGER_ELTS_NUM) { \ + case 1: { \ + using type = int; \ + { __VA_ARGS__ } \ + } break; \ + case 2: { \ + using type = int2; \ + { __VA_ARGS__ } \ + } break; \ + case 4: { \ + using type = int4; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported number of integer elements ", INTEGER_ELTS_NUM, \ + ". Expected one of: 1, 2, or 4."); \ + } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 85fbe73f00..ad4a130928 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -330,6 +330,7 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, // Per-byte K masking is still needed when only part of the register is past // original_K (i.e. row is in range but the K position spans the boundary). if constexpr (IS_PADDED_K) { +#pragma unroll for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (input_offset + thread_offset) * sizeof(int) + j; if (index % K >= original_K) { @@ -2114,69 +2115,25 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr; if (rowwise) { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>( - input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - input_stride_bytes, output_stride_bytes); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>( - input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - input_stride_bytes, output_stride_bytes); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>( - input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - input_stride_bytes, output_stride_bytes); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - } + TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_row_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_row_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); + }); } else { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>( - input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - input_stride_bytes, output_stride_bytes); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>( - input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - input_stride_bytes, output_stride_bytes); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>( - input_ptr, output_ptr, padded_m, padded_k, original_M, original_K, - input_stride_bytes, output_stride_bytes); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - } + TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_col_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_col_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); + }); } NVTE_CHECK_CUDA(cudaGetLastError()); };