diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu
index 1ea82f19cd..3fec5062ff 100644
--- a/tests/cpp/operator/test_swizzle.cu
+++ b/tests/cpp/operator/test_swizzle.cu
@@ -248,11 +248,11 @@ void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const
     const NVTEShape rs = input->rowwise_scale_inv_shape();
     zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(), rs.data[0], rs.data[1],
-                           M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);
+                           M, divide_round_up(K, BLOCK_SIZE));
     const NVTEShape cs = input->columnwise_scale_inv_shape();
     zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(), cs.data[0], cs.data[1],
-                           (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);
+                           divide_round_up(M, BLOCK_SIZE), K);
     input->from_cpu();
     input_ptrs.push_back(input.get());
@@ -444,11 +444,11 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si
     const NVTEShape rs = orig->rowwise_scale_inv_shape();
     zero_scale_inv_padding(orig->rowwise_cpu_scale_inv_ptr(), rs.data[0], rs.data[1],
-                           M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);
+                           M, divide_round_up(K, BLOCK_SIZE));
     const NVTEShape cs = orig->columnwise_scale_inv_shape();
     zero_scale_inv_padding(orig->columnwise_cpu_scale_inv_ptr(), cs.data[0], cs.data[1],
-                           (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);
+                           divide_round_up(M, BLOCK_SIZE), K);
     orig->from_cpu();
     orig_ptrs.push_back(orig.get());
@@ -541,6 +541,253 @@ INSTANTIATE_TEST_SUITE_P(
   }
 );
 
+// Build a "compact" grouped MXFP8 scale_inv buffer for swizzle input. This is
+// the layout produced by the grouped MXFP8 quantize kernel: the per-tensor
+// stride is `M_per_tensor * padded_K` (rowwise) or `DIVUP(M, 32) * padded_K_for_cols`
+// (columnwise) -- i.e. NO per-tensor padding rows are inserted. The total buffer
+// is rounded up at its very end to a multiple of 128 (rowwise) or 4 (columnwise)
+// in the grouped first dim, matching what the C++ allocator hands out.
+//
+// Each tensor's compact scales are gathered from the unpadded-prefix rows of
+// that tensor's per-tensor padded CPU scale buffer.
+namespace {
+
+struct CompactScaleBuffer {
+  test::CudaPtr<> ptr;
+  size_t numel{0};
+};
+
+CompactScaleBuffer gather_compact_grouped_scale(
+    const std::vector<std::unique_ptr<Tensor>>& tensors,
+    size_t M_per_tensor, size_t K_per_tensor, bool rowwise) {
+  using namespace test;
+  constexpr size_t BLOCK = 32;
+  const size_t num_tensors = tensors.size();
+
+  size_t per_tensor_first_unpadded;
+  size_t per_tensor_last_padded;
+  size_t group_first_align;
+  if (rowwise) {
+    per_tensor_first_unpadded = M_per_tensor;
+    per_tensor_last_padded =
+        round_up_to_nearest_multiple(divide_round_up(K_per_tensor, BLOCK), 4);
+    group_first_align = 128;
+  } else {
+    per_tensor_first_unpadded = divide_round_up(M_per_tensor, BLOCK);
+    per_tensor_last_padded = round_up_to_nearest_multiple(K_per_tensor, 128);
+    group_first_align = 4;
+  }
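+  // Illustrative numbers (hypothetical sizes, rowwise): num_tensors = 4,
+  // M = 200, K = 256 gives per_tensor_first_unpadded = 200 and
+  // per_tensor_last_padded = round_up(DIVUP(256, 32), 4) = 8, so the compact
+  // per-tensor stride is 1600 elements and total_first = round_up(800, 128) = 896,
+  // i.e. 7168 elements total -- vs 4 * round_up(200, 128) * 8 = 8192 in the
+  // per-tensor padded layout.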
+
+  const size_t per_tensor_compact_numel =
+      per_tensor_first_unpadded * per_tensor_last_padded;
+  const size_t total_first = round_up_to_nearest_multiple(
+      num_tensors * per_tensor_first_unpadded, group_first_align);
+  const size_t total_numel = total_first * per_tensor_last_padded;
+
+  std::vector<uint8_t> host_buf(total_numel, 0);
+  for (size_t i = 0; i < num_tensors; ++i) {
+    tensors[i]->to_cpu();
+    const NVTEShape padded_shape = rowwise ? tensors[i]->rowwise_scale_inv_shape()
+                                           : tensors[i]->columnwise_scale_inv_shape();
+    NVTE_CHECK(padded_shape.data[1] == per_tensor_last_padded,
+               "Unexpected per-tensor padded last dim in compact gather.");
+    const uint8_t* src = rowwise ? tensors[i]->rowwise_cpu_scale_inv_ptr()
+                                 : tensors[i]->columnwise_cpu_scale_inv_ptr();
+    uint8_t* dst = host_buf.data() + i * per_tensor_compact_numel;
+    // Per-tensor padded buffer is row-major (padded_first, padded_last); copy
+    // only the first `per_tensor_first_unpadded` rows.
+    std::memcpy(dst, src, per_tensor_compact_numel);
+  }
+
+  CompactScaleBuffer out;
+  out.ptr = cuda_alloc(total_numel);
+  NVTE_CHECK_CUDA(cudaMemcpy(out.ptr.get(), host_buf.data(),
+                             total_numel, cudaMemcpyHostToDevice));
+  out.numel = total_numel;
+  return out;
+}
+
+}  // namespace
+
+// Tests that grouped_swizzle_for_gemm correctly handles a COMPACT input
+// scale_inv buffer (no per-tensor padding rows), producing an output in the
+// per-tensor padded layout with padded regions zeroed out. This is the layout
+// produced by the grouped MXFP8 quantize kernel; previously the swizzle kernel
+// asserted that the input matched the per-tensor padded packed size, which
+// broke grouped MLP weights whose M is not a multiple of 128.
+void performTestGroupedSwizzleMXFP8CompactInput(const int num_tensors, const size_t M,
+                                                const size_t K) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  std::vector<std::unique_ptr<Tensor>> input_tensors;
+  std::vector<std::unique_ptr<Tensor>> output_tensors;
+  std::vector<Tensor*> input_ptrs, output_ptrs;
+  input_tensors.reserve(num_tensors);
+  output_tensors.reserve(num_tensors);
+  input_ptrs.reserve(num_tensors);
+  output_ptrs.reserve(num_tensors);
+
+  constexpr size_t BLOCK_SIZE = 32;
+  const std::vector<size_t> shape{M, K};
+  for (int i = 0; i < num_tensors; ++i) {
+    auto input = std::make_unique<Tensor>("input_" + std::to_string(i), shape,
+                                          DType::kFloat8E4M3, true, true,
+                                          NVTE_MXFP8_1D_SCALING);
+    auto output = std::make_unique<Tensor>("output_" + std::to_string(i), shape,
+                                           DType::kFloat8E4M3, true, true,
+                                           NVTE_MXFP8_1D_SCALING);
+    fillUniform(input.get());
+    fillUniform(output.get());
+
+    // Zero the per-tensor padded regions so the reference (which sees the
+    // padded layout) and the kernel (which sees the compact layout but writes
+    // zeros into output padding) agree byte-for-byte.
+    input->to_cpu();
+    const NVTEShape rs = input->rowwise_scale_inv_shape();
+    zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(),
+                           rs.data[0], rs.data[1],
+                           M, divide_round_up(K, BLOCK_SIZE));
+    const NVTEShape cs = input->columnwise_scale_inv_shape();
+    zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(),
+                           cs.data[0], cs.data[1],
+                           divide_round_up(M, BLOCK_SIZE), K);
+    input->from_cpu();
+
+    input_ptrs.push_back(input.get());
+    output_ptrs.push_back(output.get());
+    input_tensors.emplace_back(std::move(input));
+    output_tensors.emplace_back(std::move(output));
+  }
+
+  // Build a per-tensor padded grouped output via the standard helper, and a
+  // compact-scale grouped input by overriding the scale_inv buffers of a
+  // padded grouped input with newly allocated compact buffers.
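+  // Note: the override below makes the grouped input's scale_inv numel equal
+  // the compact total rather than num_tensors * padded numel; the host-side
+  // swizzle entry point keys its compact-layout detection off exactly this
+  // size difference.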
+  GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
+  GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);
+
+  CompactScaleBuffer compact_row =
+      gather_compact_grouped_scale(input_tensors, M, K, /*rowwise=*/true);
+  CompactScaleBuffer compact_col =
+      gather_compact_grouped_scale(input_tensors, M, K, /*rowwise=*/false);
+
+  grouped_input.scale_inv = std::move(compact_row.ptr);
+  grouped_input.columnwise_scale_inv = std::move(compact_col.ptr);
+  {
+    NVTEShape s = nvte_make_shape(&compact_row.numel, 1);
+    NVTEBasicTensor t{grouped_input.scale_inv.get(), kNVTEFloat8E8M0, s};
+    nvte_set_grouped_tensor_param(grouped_input.get_handle(),
+                                  kNVTEGroupedRowwiseScaleInv, &t, sizeof(t));
+  }
+  {
+    NVTEShape s = nvte_make_shape(&compact_col.numel, 1);
+    NVTEBasicTensor t{grouped_input.columnwise_scale_inv.get(), kNVTEFloat8E8M0, s};
+    nvte_set_grouped_tensor_param(grouped_input.get_handle(),
+                                  kNVTEGroupedColumnwiseScaleInv, &t, sizeof(t));
+  }
+
+  const uint8_t input_swizzled = 0;
+  nvte_set_grouped_tensor_param(grouped_input.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &input_swizzled, sizeof(input_swizzled));
+  const uint8_t output_swizzled = 1;
+  nvte_set_grouped_tensor_param(grouped_output.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &output_swizzled, sizeof(output_swizzled));
+
+  const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape();
+  const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape();
+  const size_t row_numel = row_shape.data[0] * row_shape.data[1];
+  const size_t col_numel = col_shape.data[0] * col_shape.data[1];
+
+  // Memset the outputs to a non-zero sentinel so that a kernel that fails to
+  // overwrite the padded regions (which it must fill with zeros) is detected.
+  NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0xCD,
+                             num_tensors * row_numel));
+  NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0xCD,
+                             num_tensors * col_numel));
+
+  nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(),
+                                       grouped_output.get_handle(), 0);
+  cudaDeviceSynchronize();
+  auto err = cudaGetLastError();
+  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+  std::vector<uint8_t> output_row(num_tensors * row_numel);
+  std::vector<uint8_t> output_col(num_tensors * col_numel);
+  NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), grouped_output.scale_inv.get(),
+                             output_row.size(), cudaMemcpyDeviceToHost));
+  NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(),
+                             grouped_output.columnwise_scale_inv.get(),
+                             output_col.size(), cudaMemcpyDeviceToHost));
+
+  std::vector<uint8_t> ref_row(num_tensors * row_numel);
+  std::vector<uint8_t> ref_col(num_tensors * col_numel);
+  for (int i = 0; i < num_tensors; ++i) {
+    compute_ref_swizzle<128, 4, true>(
+        input_tensors[i]->rowwise_cpu_scale_inv_ptr(),
+        ref_row.data() + i * row_numel,
+        row_shape.data[0], row_shape.data[1]);
+    compute_ref_swizzle<128, 4, false>(
+        input_tensors[i]->columnwise_cpu_scale_inv_ptr(),
+        ref_col.data() + i * col_numel,
+        col_shape.data[1], col_shape.data[0]);
+  }
+
+  compareResults("grouped_swizzle_compact_rowwise", output_row.data(),
+                 ref_row.data(), num_tensors * row_numel);
+  compareResults("grouped_swizzle_compact_colwise", output_col.data(),
+                 ref_col.data(), num_tensors * col_numel);
+}
+
+class SwizzleGroupedCompactInputTestSuite
+    : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};
+
+TEST_P(SwizzleGroupedCompactInputTestSuite, TestGroupedSwizzleMXFP8CompactInput) {
+  const auto num_tensors = std::get<0>(GetParam());
+  const auto M = std::get<1>(GetParam());
+  const auto K = std::get<2>(GetParam());
+  performTestGroupedSwizzleMXFP8CompactInput(num_tensors, M, K);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    OperatorTest,
+    SwizzleGroupedCompactInputTestSuite,
+    ::testing::Values(
+        // Aligned M and K. Per-tensor compact stride == per-tensor padded stride,
+        // so the kernel may use either layout; serves as a sanity check that the
+        // compact-input plumbing doesn't regress aligned shapes.
+        std::make_tuple(3, 256, 256),
+        std::make_tuple(4, 128, 128),
+        // M NOT divisible by 128 (the original-bug case): the per-tensor compact
+        // stride shrinks vs the padded one. We pick (num_tensors, M) so that BOTH
+        //   round_up(N * M, 128)         != N * round_up(M, 128)          (rowwise)
+        //   round_up(N * DIVUP(M,32), 4) != N * round_up(DIVUP(M,32), 4)  (colwise)
+        // i.e. compact_total != padded_total on either axis, so the kernel
+        // unambiguously detects the compact layout.
+        std::make_tuple(4, 200, 256),
+        std::make_tuple(4, 65, 256),
+        std::make_tuple(2, 2880, 2880),  // shape from the originally failing workload
+        // K not divisible by 128 (DIVUP(K,32) padded up to a multiple of 4).
+        std::make_tuple(3, 256, 160),
+        std::make_tuple(2, 256, 96),
+        // Neither M nor K aligned.
+        std::make_tuple(4, 200, 160),
+        std::make_tuple(4, 33, 64),
+        std::make_tuple(2, 1, 32),
+        // num_tensors * M not aligned to 128 -> exercises the trailing alignment
+        // slack at the end of the compact rowwise buffer.
+        std::make_tuple(3, 64, 128),
+        std::make_tuple(5, 33, 96)
+    ),
+    [](const testing::TestParamInfo<SwizzleGroupedCompactInputTestSuite::ParamType>& info) {
+      return "n" + std::to_string(std::get<0>(info.param)) +
+             "_M" + std::to_string(std::get<1>(info.param)) +
+             "_K" + std::to_string(std::get<2>(info.param));
+    }
+);
+
 class UnswizzleGroupedTestSuite
     : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 68aa0f4c51..c1b3f8f427 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -946,6 +946,26 @@ struct TypeInfo {
   } \
 }
 
+#define TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(INTEGER_ELTS_NUM, type, ...) \
+  switch (INTEGER_ELTS_NUM) {                                                 \
+    case 1: {                                                                 \
+      using type = int;                                                       \
+      { __VA_ARGS__ }                                                         \
+    } break;                                                                  \
+    case 2: {                                                                 \
+      using type = int2;                                                      \
+      { __VA_ARGS__ }                                                         \
+    } break;                                                                  \
+    case 4: {                                                                 \
+      using type = int4;                                                      \
+      { __VA_ARGS__ }                                                         \
+    } break;                                                                  \
+    default: {                                                                \
+      NVTE_ERROR("Unsupported number of integer elements ", INTEGER_ELTS_NUM, \
+                 ". Expected one of: 1, 2, or 4.");                           \
+    }                                                                         \
+  }
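+// Illustrative usage (hypothetical kernel name), binding `type` to int, int2,
+// or int4 for a vectorized load width of 1, 2, or 4 ints, and erroring otherwise:
+//   TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType, {
+//     my_kernel<LType><<<grid, block, 0, stream>>>(args);
+//   });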
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 inline int log2_ceil(int value) {
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index de4fdbb040..ad4a130928 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -91,7 +91,11 @@ __device__ inline void regs_unshuffle_with_bit_shifts(LType* regs_vec) {
   for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i];
 }
 
-template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, typename LType>
+// IS_PADDED_K / IS_PADDED_M select the boundary-block specialization at compile
+// time so the inner load loop avoids the per-iteration runtime checks. The
+// caller computes the runtime predicates from blockIdx/gridDim once per block
+// (uniform across the block) and dispatches to the right specialization.
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, typename LType, bool IS_PADDED_K, bool IS_PADDED_M>
 __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M,
                                                 const int K, const int original_M,
                                                 const int original_K, const int bid_x,
@@ -117,9 +121,6 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output,
     m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
   }
 
-  bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
-  bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
-
   const int input_offset =
       bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
   const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) + input_offset;
@@ -132,19 +133,37 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output,
   extern __shared__ int slm[];
 
   // load, global -> regs
+  // Each register read for a given i runs along the M direction at K-coord
+  // (bid_x * TB_DIM * SF_TILE_DIM_K + threadIdx.y * SF_TILE_DIM_K + i). When that
+  // K-coord is past original_K, the entire register is outside the per-tensor data
+  // region (which may be the unpadded compact extent), so we must NOT issue the
+  // __ldg there -- it could read past the per-tensor buffer (and, for the last
+  // tensor in a grouped allocation, past the end of the allocation entirely).
   LType regs_vec[N_SF_PER_TD_PER_TILE];
   if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
       threadIdx.y < k_tiles_in_tb) {
+    const int k_base = bid_x * TB_DIM * SF_TILE_DIM_K + threadIdx.y * SF_TILE_DIM_K;
 #pragma unroll
     for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
       const int thread_offset =
           (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD;
+      const int k_coord = k_base + i;
+      if constexpr (IS_PADDED_K) {
+        if (k_coord >= original_K) {
+          // Entire register is past original_K: zero it directly without loading.
+          uint8_t* zero_bytes = reinterpret_cast<uint8_t*>(regs_vec + i);
+#pragma unroll
+          for (int j = 0; j < static_cast<int>(sizeof(LType)); j++) zero_bytes[j] = 0;
+          continue;
+        }
+      }
       regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
-      // Pad zeros
-      if (padding_m || padding_k) {
+      // Per-byte M masking is still needed when only part of the register is past
+      // original_M (i.e. the K-coord is in range but the M position spans the boundary).
+      if constexpr (IS_PADDED_M) {
         for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
           const int index = (input_offset + thread_offset) * sizeof(int) + j;
-          if (index / M >= original_K || index % M >= original_M) {
+          if (index % M >= original_M) {
             reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
           }
         }
@@ -183,12 +202,43 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output,
   }
 }
 
+// Dispatch helper: pick the right (IS_PADDED_K, IS_PADDED_M) col-scaling impl
+// specialization at runtime based on the per-block padding predicates. The
+// branching here is uniform across all threads in the block, so the path each
+// block takes still inlines cleanly.
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, typename LType>
+__device__ __forceinline__ void dispatch_swizzle_col_scaling_kernel_impl(
+    const void* input, void* output, const int M, const int K, const int original_M,
+    const int original_K, const int bid_x, const int bid_y, const int grid_dim_x,
+    const int grid_dim_y, const bool padding_k, const bool padding_m) {
+  if (padding_k && padding_m) {
+    swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType, true, true>(
+        input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+  } else if (padding_k) {
+    swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType, true, false>(
+        input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+  } else if (padding_m) {
+    swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType, false, true>(
+        input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+  } else {
+    swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType, false, false>(
+        input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+  }
+}
+
 template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, typename LType>
 __global__ void __launch_bounds__(TB_DIM* TB_DIM)
     swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K,
                                const int original_M, const int original_K) {
-  swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>(
-      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
+  const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M);
+  const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K);
+  dispatch_swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>(
+      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y,
+      padding_k, padding_m);
 }
 
 template <typename LType>
@@ -224,7 +274,11 @@ __device__ inline void regs_unshuffle(LType* regs_vec) {
   for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i];
 }
 
-template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, typename LType>
+// IS_PADDED_K / IS_PADDED_M select the boundary-block specialization at compile
+// time so the inner load loop avoids the per-iteration runtime checks. The
The +// caller computes the runtime predicates from blockIdx/gridDim once per block +// (uniform across the block) and dispatches to the right specialization. +template __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, const int bid_x, @@ -243,9 +297,6 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; } - bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); - bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); - const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; const int* input_i32 = reinterpret_cast(input) + input_offset; int* output_i32 = reinterpret_cast(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 + @@ -254,17 +305,35 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, extern __shared__ int4 slm_v4i[]; // load, global -> regs + // Each register read for a given i is along the K direction at row + // (bid_y * SF_TILE_DIM_M + i * TB_DIM + threadIdx.y). When that row is past + // original_M, the entire register is out of the per-tensor data region (which + // may be the unpadded compact extent), so we must NOT issue the __ldg there -- + // it could read past the per-tensor buffer (and, for the last tensor in a + // grouped allocation, past the end of the allocation entirely). LType regs_vec[N_SF_PER_TD_PER_TILE]; if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int row = bid_y * SF_TILE_DIM_M + i * TB_DIM + threadIdx.y; const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + if constexpr (IS_PADDED_M) { + if (row >= original_M) { + // Entire register is past original_M: zero directly without loading. + uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); +#pragma unroll + for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; + continue; + } + } regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); - if (padding_m || padding_k) { - // Pad zeros + // Per-byte K masking is still needed when only part of the register is past + // original_K (i.e. row is in range but the K position spans the boundary). + if constexpr (IS_PADDED_K) { +#pragma unroll for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (input_offset + thread_offset) * sizeof(int) + j; - if (index / K >= original_M || index % K >= original_K) { + if (index % K >= original_K) { reinterpret_cast(regs_vec + i)[j] = 0; } } @@ -293,12 +362,43 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, } } +// Dispatch helper: pick the right (IS_PADDED_K, IS_PADDED_M) row-scaling impl +// specialization at runtime based on the per-block padding predicates. The +// branching here is uniform across all threads in the block, so the indirect +// path each block takes still inlines cleanly. 
+      if constexpr (IS_PADDED_K) {
+#pragma unroll
         for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
           const int index = (input_offset + thread_offset) * sizeof(int) + j;
-          if (index / K >= original_M || index % K >= original_K) {
+          if (index % K >= original_K) {
             reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
           }
         }
@@ -293,12 +362,43 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output,
   }
 }
 
+// Dispatch helper: pick the right (IS_PADDED_K, IS_PADDED_M) row-scaling impl
+// specialization at runtime based on the per-block padding predicates. The
+// branching here is uniform across all threads in the block, so the path each
+// block takes still inlines cleanly.
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, typename LType>
+__device__ __forceinline__ void dispatch_swizzle_row_scaling_kernel_impl(
+    const void* input, void* output, const int M, const int K, const int original_M,
+    const int original_K, const int bid_x, const int bid_y, const int grid_dim_x,
+    const int grid_dim_y, const bool padding_k, const bool padding_m) {
+  if (padding_k && padding_m) {
+    swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType, true, true>(
+        input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+  } else if (padding_k) {
+    swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType, true, false>(
+        input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+  } else if (padding_m) {
+    swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType, false, true>(
+        input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+  } else {
+    swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType, false, false>(
+        input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
+  }
+}
+
 template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, typename LType>
 __global__ void __launch_bounds__(TB_DIM* TB_DIM)
     swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K,
                                const int original_M, const int original_K) {
-  swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>(
-      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
+  const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M);
+  const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K);
+  dispatch_swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>(
+      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y,
+      padding_k, padding_m);
 }
 
 // Narrow-K specialization for row scaling swizzle.
@@ -628,14 +728,21 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM)
     grouped_swizzle_row_scaling_uniform_shape_kernel(const void* input, void* output, const int M,
                                                      const int K, const int original_M,
                                                      const int original_K,
-                                                     const size_t scale_stride_bytes) {
+                                                     const size_t input_stride_bytes,
+                                                     const size_t output_stride_bytes) {
   const int tensor_id = blockIdx.z;
+  // Input and output strides may differ: the input is in the kernel-produced
+  // "compact" layout (per-tensor stride = original_M * padded_k * elem_size) when
+  // callers pass the unswizzled grouped scale buffer as-is, while the output is
+  // always in the per-tensor padded ("swizzle-ready") layout
+  // (padded_m * padded_k * elem_size).
   const uint8_t* input_base =
-      reinterpret_cast<const uint8_t*>(input) + tensor_id * scale_stride_bytes;
-  uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + tensor_id * scale_stride_bytes;
-  swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>(
+      reinterpret_cast<const uint8_t*>(input) + tensor_id * input_stride_bytes;
+  uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + tensor_id * output_stride_bytes;
+  const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M);
+  const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K);
+  dispatch_swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>(
       input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x,
-      gridDim.y);
+      gridDim.y, padding_k, padding_m);
 }
 
 template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, typename LType>
@@ -643,14 +750,20 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM)
     grouped_swizzle_col_scaling_uniform_shape_kernel(const void* input, void* output, const int M,
                                                      const int K, const int original_M,
                                                      const int original_K,
-                                                     const size_t scale_stride_bytes) {
+                                                     const size_t input_stride_bytes,
+                                                     const size_t output_stride_bytes) {
   const int tensor_id = blockIdx.z;
+  // See the rowwise kernel for the stride semantics. For columnwise, the
+  // per-tensor compact stride is the unpadded scale-row count in the K
+  // direction times the padded M extent (host-side: DIVUP(k, 32) * padded_m
+  // * elem_size).
the + // unpadded scale-row count in the K direction times the padded M extent). const uint8_t* input_base = - reinterpret_cast(input) + tensor_id * scale_stride_bytes; - uint8_t* output_base = reinterpret_cast(output) + tensor_id * scale_stride_bytes; - swizzle_col_scaling_kernel_impl( + reinterpret_cast(input) + tensor_id * input_stride_bytes; + uint8_t* output_base = reinterpret_cast(output) + tensor_id * output_stride_bytes; + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_col_scaling_kernel_impl( input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, - gridDim.y); + gridDim.y, padding_k, padding_m); } template @@ -751,8 +864,11 @@ __global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_ const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; - swizzle_row_scaling_kernel_impl( - input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + const bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + const bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + dispatch_swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y, padding_k, + padding_m); } template @@ -781,8 +897,11 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; - swizzle_col_scaling_kernel_impl( - input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + const bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + const bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + dispatch_swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y, padding_k, + padding_m); } template @@ -1924,23 +2043,56 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* const size_t padded_m = round_up_to_multiple(m, 128); const size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); - const size_t scale_elems = padded_m * padded_k; + // Per-tensor scale-element counts: + // - "padded" layout: each tensor occupies padded_m * padded_k elements + // (total buffer = num_tensors * padded_m * padded_k). + // - "compact" layout (what the grouped MXFP8 quantize kernel actually writes): + // per-tensor stride is m * padded_k (rowwise) or DIVUP(k,32) * padded_m + // (columnwise) and the total buffer the C++ allocator hands out has its + // grouped first dim padded up to a multiple of 128 (rowwise) or 4 + // (columnwise) -- so the buffer may be slightly larger than + // num_tensors * compact_scale_elems, with trailing alignment slack at + // the very end (never read because of the per-tensor row/k guard in the + // kernel impl). + // The output is always written in the padded layout. The input may be in + // either layout; the kernel handles the compact case safely by using + // different per-tensor strides for input vs output and skipping loads past + // the per-tensor extent. + const size_t padded_scale_elems = padded_m * padded_k; + const size_t compact_scale_elems = + rowwise ? 
+  const size_t padded_scale_elems = padded_m * padded_k;
+  const size_t compact_scale_elems =
+      rowwise ? m * padded_k : DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)) * padded_m;
+  const size_t compact_total_scale_elems =
+      rowwise ? round_up_to_multiple(input->num_tensors * m, 128) * padded_k
+              : round_up_to_multiple(
+                    input->num_tensors * DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4) *
+                    padded_m;
   const size_t scale_elem_size =
       rowwise ? typeToSize(input->scale_inv.dtype) : typeToSize(input->columnwise_scale_inv.dtype);
-  const size_t scale_stride_bytes = scale_elems * scale_elem_size;
 
-  if (rowwise) {
-    NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems,
-               "Grouped input scale_inv size does not match expected packed size.");
-    NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems,
-               "Grouped output scale_inv size does not match expected packed size.");
+  const size_t input_scale_numel =
+      rowwise ? input->scale_inv.numel() : input->columnwise_scale_inv.numel();
+  const size_t output_scale_numel =
+      rowwise ? output->scale_inv.numel() : output->columnwise_scale_inv.numel();
+
+  bool input_is_compact;
+  if (input_scale_numel == input->num_tensors * padded_scale_elems) {
+    input_is_compact = false;
+  } else if (input_scale_numel == compact_total_scale_elems) {
+    input_is_compact = true;
   } else {
-    NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
-               "Grouped input columnwise_scale_inv size does not match expected packed size.");
-    NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
-               "Grouped output columnwise_scale_inv size does not match expected packed size.");
+    NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
+               " size does not match expected packed size (got ", input_scale_numel,
+               ", expected either ", input->num_tensors * padded_scale_elems,
+               " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
   }
+  NVTE_CHECK(output_scale_numel == input->num_tensors * padded_scale_elems, "Grouped output ",
+             (rowwise ? "scale_inv" : "columnwise_scale_inv"),
+             " size does not match expected per-tensor padded size.");
+
+  const size_t input_stride_bytes =
+      (input_is_compact ? compact_scale_elems : padded_scale_elems) * scale_elem_size;
+  const size_t output_stride_bytes = padded_scale_elems * scale_elem_size;
 
   const int num_tiles_m = padded_m / SF_TILE_DIM_M;
   const int num_tiles_k = padded_k / SF_TILE_DIM_K;
@@ -1963,69 +2115,25 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
   void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;
 
   if (rowwise) {
-    switch (vec_load_size) {
-      case 4:
-        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-            grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int4>,
-            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-        grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int4>
-            <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                           padded_k, original_M, original_K,
-                                                           scale_stride_bytes);
-        break;
-      case 2:
-        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-            grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int2>,
-            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-        grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int2>
-            <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                           padded_k, original_M, original_K,
-                                                           scale_stride_bytes);
-        break;
-      case 1:
-        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-            grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int>,
-            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-        grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int>
-            <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                           padded_k, original_M, original_K,
-                                                           scale_stride_bytes);
-        break;
-      default:
-        NVTE_ERROR("Not valid vec_load_size.");
-    }
+    TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType, {
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+      grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>
+          <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                         padded_k, original_M, original_K,
+                                                         input_stride_bytes, output_stride_bytes);
+    });
   } else {
-    switch (vec_load_size) {
-      case 4:
-        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-            grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int4>,
-            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-        grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int4>
-            <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                           padded_k, original_M, original_K,
-                                                           scale_stride_bytes);
-        break;
-      case 2:
-        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-            grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int2>,
-            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-        grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int2>
-            <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                           padded_k, original_M, original_K,
-                                                           scale_stride_bytes);
-        break;
-      case 1:
-        NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-            grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int>,
-            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-        grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, int>
-            <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                           padded_k, original_M, original_K,
-                                                           scale_stride_bytes);
-        break;
-      default:
-        NVTE_ERROR("Not valid vec_load_size.");
-    }
+    TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType, {
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+      grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, LType>
+          <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                         padded_k, original_M, original_K,
+                                                         input_stride_bytes, output_stride_bytes);
+    });
   }
   NVTE_CHECK_CUDA(cudaGetLastError());
 };
diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
index cbaabaad17..d8ab830c48 100644
--- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
@@ -403,16 +403,39 @@ std::optional<SwizzledGroupedScales> maybe_swizzle_grouped_tensor(GroupedTensorW
                       tensor_offsets.data_ptr, static_cast<DType>(tensor_offsets.dtype),
                       tensor_offsets.shape);
   }
 
+  // Per-tensor logical dimensions (uniform-shape grouped tensor).
+  const size_t num_tensors = input.num_tensors();
+  const auto logical_shape_nvte = input.logical_shape();
+  NVTE_CHECK(logical_shape_nvte.ndim >= 2,
+             "Grouped GEMM swizzle expects logical_shape with ndim >= 2.");
+  const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
+  const size_t per_tensor_last_dim = logical_shape_nvte.data[logical_shape_nvte.ndim - 1];
+  constexpr size_t kMxfp8BlockSize = 32;
+
+  // The output is always allocated in the per-tensor padded ("swizzle-ready")
+  // layout so the cuDNN grouped GEMM consumer sees the correct stride between
+  // experts. The swizzle kernel itself handles converting from the
+  // kernel-emitted compact layout (where the per-tensor first dim is the
+  // unpadded value) to this padded layout.
+  auto compute_padded_grouped_scale_shape = [&](bool rowwise) {
+    const size_t m = rowwise ? per_tensor_first_dim : per_tensor_last_dim;
+    const size_t k = rowwise ? per_tensor_last_dim : per_tensor_first_dim;
+    const size_t padded_m = ceildiv(m, size_t{128}) * 128;
+    const size_t padded_k = ceildiv(ceildiv(k, kMxfp8BlockSize), size_t{4}) * 4;
+    return std::vector<size_t>{num_tensors * padded_m, padded_k};
+  };
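+  // Illustrative (the originally failing workload): num_tensors = 2 with
+  // per-tensor 2880 x 2880 gives, rowwise, padded_m = ceildiv(2880, 128) * 128
+  // = 2944 and padded_k = ceildiv(ceildiv(2880, 32), 4) * 4 = 92, i.e. an
+  // allocation of {5888, 92} -- larger than the compact buffer the grouped
+  // quantize kernel wrote.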
+
   if (swizzle_rowwise) {
     const auto data = input.get_rowwise_data();
     const auto data_dtype = static_cast<DType>(data.dtype);
     const auto scales_dtype = static_cast<DType>(row_scales.dtype);
     swizzle_input.set_rowwise_data(nullptr, data_dtype, data.shape);
     swizzle_input.set_rowwise_scale_inv(row_scales.data_ptr, scales_dtype, row_scales.shape);
-    rowwise_scales_pyt = allocateSpace(row_scales.shape, scales_dtype, false);
+    const auto padded_shape = compute_padded_grouped_scale_shape(/*rowwise=*/true);
+    rowwise_scales_pyt = allocateSpace(padded_shape, scales_dtype, false);
+    NVTEShape padded_shape_nvte = nvte_make_shape(padded_shape.data(), padded_shape.size());
     swizzle_output.set_rowwise_data(nullptr, data_dtype, data.shape);
     swizzle_output.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype,
-                                         row_scales.shape);
+                                         padded_shape_nvte);
   }
   if (swizzle_columnwise) {
     const auto data = input.get_columnwise_data();
@@ -420,10 +443,12 @@ std::optional<SwizzledGroupedScales> maybe_swizzle_grouped_tensor(GroupedTensorW
     const auto scales_dtype = static_cast<DType>(col_scales.dtype);
     swizzle_input.set_columnwise_data(nullptr, data_dtype, data.shape);
     swizzle_input.set_columnwise_scale_inv(col_scales.data_ptr, scales_dtype, col_scales.shape);
-    columnwise_scales_pyt = allocateSpace(col_scales.shape, scales_dtype, false);
+    const auto padded_shape = compute_padded_grouped_scale_shape(/*rowwise=*/false);
+    columnwise_scales_pyt = allocateSpace(padded_shape, scales_dtype, false);
+    NVTEShape padded_shape_nvte = nvte_make_shape(padded_shape.data(), padded_shape.size());
     swizzle_output.set_columnwise_data(nullptr, data_dtype, data.shape);
     swizzle_output.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype,
-                                            col_scales.shape);
+                                            padded_shape_nvte);
   }
   swizzle_output.set_with_gemm_swizzled_scales(true);
@@ -434,12 +459,13 @@ std::optional<SwizzledGroupedScales> maybe_swizzle_grouped_tensor(GroupedTensorW
 
   if (swizzle_rowwise) {
     const auto scales_dtype = static_cast<DType>(row_scales.dtype);
-    input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, row_scales.shape);
+    input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype,
+                                getTensorShape(*rowwise_scales_pyt));
   }
   if (swizzle_columnwise) {
     const auto scales_dtype = static_cast<DType>(col_scales.dtype);
     input.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype,
-                                   col_scales.shape);
+                                   getTensorShape(*columnwise_scales_pyt));
   }
   input.set_with_gemm_swizzled_scales(true);
   return SwizzledGroupedScales{std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)};
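
For reference, a minimal standalone sketch (not part of the patch) of the two buffer-size computations the host-side layout detection distinguishes. It assumes the MXFP8 block size of 32 and the 128/4 alignment rules used throughout this change; `divup` and `round_up` are local stand-ins for the repo's `DIVUP` / `round_up_to_multiple` helpers:

```cpp
#include <cstddef>
#include <cstdio>

static size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }
static size_t round_up(size_t a, size_t b) { return divup(a, b) * b; }

int main() {
  // Hypothetical grouped MXFP8 weights: the originally failing workload.
  const size_t num_tensors = 2, m = 2880, k = 2880;
  const size_t padded_m = round_up(m, 128);           // 2944
  const size_t padded_k = round_up(divup(k, 32), 4);  // 92
  // Per-tensor padded layout: every tensor occupies padded_m * padded_k elements.
  const size_t padded_total = num_tensors * padded_m * padded_k;
  // Compact rowwise layout: per-tensor stride is m * padded_k, and only the very
  // end of the buffer is rounded up (grouped first dim to a multiple of 128).
  const size_t compact_total = round_up(num_tensors * m, 128) * padded_k;
  std::printf("padded total = %zu, compact total = %zu\n", padded_total, compact_total);
  return 0;  // prints 541696 vs 529920 -- the size mismatch the detection keys on
}
```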