tests/cpp/operator/test_swizzle.cu (251 additions, 4 deletions)
@@ -248,11 +248,11 @@ void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const
const NVTEShape rs = input->rowwise_scale_inv_shape();
zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr<uint8_t>(),
rs.data[0], rs.data[1],
- M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);
+ M, divide_round_up(K, BLOCK_SIZE));
const NVTEShape cs = input->columnwise_scale_inv_shape();
zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr<uint8_t>(),
cs.data[0], cs.data[1],
- (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);
+ divide_round_up(M, BLOCK_SIZE), K);
input->from_cpu();

input_ptrs.push_back(input.get());
@@ -444,11 +444,11 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si
const NVTEShape rs = orig->rowwise_scale_inv_shape();
zero_scale_inv_padding(orig->rowwise_cpu_scale_inv_ptr<uint8_t>(),
rs.data[0], rs.data[1],
- M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);
+ M, divide_round_up(K, BLOCK_SIZE));
const NVTEShape cs = orig->columnwise_scale_inv_shape();
zero_scale_inv_padding(orig->columnwise_cpu_scale_inv_ptr<uint8_t>(),
cs.data[0], cs.data[1],
- (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);
+ divide_round_up(M, BLOCK_SIZE), K);
orig->from_cpu();

orig_ptrs.push_back(orig.get());
@@ -541,6 +541,253 @@ INSTANTIATE_TEST_SUITE_P(
}
);

// Build a "compact" grouped MXFP8 scale_inv buffer for swizzle input. This is
// the layout produced by the grouped MXFP8 quantize kernel: the per-tensor
// stride is `M_per_tensor * padded_K` (rowwise) or `DIVUP(M,32) * padded_K_for_cols`
// (columnwise) -- i.e. NO per-tensor padding rows are inserted. The total buffer
// is rounded up at its very end to a multiple of 128 (rowwise) or 4 (columnwise)
// in the grouped first dim, matching what the C++ allocator hands out.
//
// Each tensor's compact scales are gathered from the unpadded-prefix rows of
// that tensor's per-tensor padded CPU scale buffer.
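//
// Illustrative numbers (M = 200, K = 256), derived from the formulas below:
// rowwise, padded_K = round_up(DIVUP(256, 32), 4) = 8, so the compact
// per-tensor stride is 200 * 8 = 1600 scales, where the per-tensor padded
// layout would hold round_up(200, 128) * 8 = 2048; columnwise, the compact
// stride is DIVUP(200, 32) * round_up(256, 128) = 7 * 256 = 1792 scales vs.
// round_up(7, 4) * 256 = 2048 padded.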
namespace {

struct CompactScaleBuffer {
test::CudaPtr<> ptr;
size_t numel{0};
};

CompactScaleBuffer gather_compact_grouped_scale(
const std::vector<std::unique_ptr<test::Tensor>>& tensors,
size_t M_per_tensor, size_t K_per_tensor, bool rowwise) {
using namespace test;
constexpr size_t BLOCK = 32;
const size_t num_tensors = tensors.size();

size_t per_tensor_first_unpadded;
size_t per_tensor_last_padded;
size_t group_first_align;
if (rowwise) {
per_tensor_first_unpadded = M_per_tensor;
per_tensor_last_padded =
round_up_to_nearest_multiple(divide_round_up(K_per_tensor, BLOCK), 4);
group_first_align = 128;
} else {
per_tensor_first_unpadded = divide_round_up(M_per_tensor, BLOCK);
per_tensor_last_padded = round_up_to_nearest_multiple(K_per_tensor, 128);
group_first_align = 4;
}

const size_t per_tensor_compact_numel =
per_tensor_first_unpadded * per_tensor_last_padded;
const size_t total_first = round_up_to_nearest_multiple(
num_tensors * per_tensor_first_unpadded, group_first_align);
const size_t total_numel = total_first * per_tensor_last_padded;

std::vector<uint8_t> host_buf(total_numel, 0);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
const NVTEShape padded_shape = rowwise ? tensors[i]->rowwise_scale_inv_shape()
: tensors[i]->columnwise_scale_inv_shape();
NVTE_CHECK(padded_shape.data[1] == per_tensor_last_padded,
"Unexpected per-tensor padded last dim in compact gather.");
const uint8_t* src = rowwise
? tensors[i]->rowwise_cpu_scale_inv_ptr<uint8_t>()
: tensors[i]->columnwise_cpu_scale_inv_ptr<uint8_t>();
uint8_t* dst = host_buf.data() + i * per_tensor_compact_numel;
// Per-tensor padded buffer is row-major (padded_first, padded_last); copy
// only the first `per_tensor_first_unpadded` rows.
std::memcpy(dst, src, per_tensor_compact_numel);
}

CompactScaleBuffer out;
out.ptr = cuda_alloc(total_numel);
NVTE_CHECK_CUDA(cudaMemcpy(out.ptr.get(), host_buf.data(),
total_numel, cudaMemcpyHostToDevice));
out.numel = total_numel;
return out;
}

} // namespace

// Tests that grouped_swizzle_for_gemm correctly handles a COMPACT input
// scale_inv buffer (no per-tensor padding rows), producing an output in the
// per-tensor padded layout with padded regions zeroed out. This is the layout
// produced by the grouped MXFP8 quantize kernel; previously the swizzle kernel
// asserted that the input matched the per-tensor padded packed size, which
// broke grouped MLP weights with M not a multiple of 128.
void performTestGroupedSwizzleMXFP8CompactInput(const int num_tensors, const size_t M,
const size_t K) {
using namespace transformer_engine;
using namespace test;

std::vector<std::unique_ptr<Tensor>> input_tensors;
std::vector<std::unique_ptr<Tensor>> output_tensors;
std::vector<Tensor*> input_ptrs, output_ptrs;
input_tensors.reserve(num_tensors);
output_tensors.reserve(num_tensors);
input_ptrs.reserve(num_tensors);
output_ptrs.reserve(num_tensors);

constexpr size_t BLOCK_SIZE = 32;
const std::vector<size_t> shape{M, K};
for (int i = 0; i < num_tensors; ++i) {
auto input = std::make_unique<Tensor>("input_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true,
NVTE_MXFP8_1D_SCALING);
auto output = std::make_unique<Tensor>("output_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true,
NVTE_MXFP8_1D_SCALING);
fillUniform(input.get());
fillUniform(output.get());

// Zero the per-tensor padded regions so the reference (which sees the
// padded layout) and the kernel (which sees the compact layout but writes
// zeros into output padding) agree byte-for-byte.
input->to_cpu();
const NVTEShape rs = input->rowwise_scale_inv_shape();
zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr<uint8_t>(),
rs.data[0], rs.data[1],
M, divide_round_up(K, BLOCK_SIZE));
const NVTEShape cs = input->columnwise_scale_inv_shape();
zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr<uint8_t>(),
cs.data[0], cs.data[1],
divide_round_up(M, BLOCK_SIZE), K);
input->from_cpu();

input_ptrs.push_back(input.get());
output_ptrs.push_back(output.get());
input_tensors.emplace_back(std::move(input));
output_tensors.emplace_back(std::move(output));
}

// Build a per-tensor padded grouped output via the standard helper, and a
// compact-scale grouped input by overriding the scale_inv buffers of a
// padded grouped input with newly allocated compact buffers.
GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);

CompactScaleBuffer compact_row =
gather_compact_grouped_scale(input_tensors, M, K, /*rowwise=*/true);
CompactScaleBuffer compact_col =
gather_compact_grouped_scale(input_tensors, M, K, /*rowwise=*/false);

grouped_input.scale_inv = std::move(compact_row.ptr);
grouped_input.columnwise_scale_inv = std::move(compact_col.ptr);
{
NVTEShape s = nvte_make_shape(&compact_row.numel, 1);
NVTEBasicTensor t{grouped_input.scale_inv.get(), kNVTEFloat8E8M0, s};
nvte_set_grouped_tensor_param(grouped_input.get_handle(),
kNVTEGroupedRowwiseScaleInv, &t, sizeof(t));
}
{
NVTEShape s = nvte_make_shape(&compact_col.numel, 1);
NVTEBasicTensor t{grouped_input.columnwise_scale_inv.get(), kNVTEFloat8E8M0, s};
nvte_set_grouped_tensor_param(grouped_input.get_handle(),
kNVTEGroupedColumnwiseScaleInv, &t, sizeof(t));
}

const uint8_t input_swizzled = 0;
nvte_set_grouped_tensor_param(grouped_input.get_handle(),
kNVTEGroupedWithGEMMSwizzledScales,
&input_swizzled, sizeof(input_swizzled));
const uint8_t output_swizzled = 1;
nvte_set_grouped_tensor_param(grouped_output.get_handle(),
kNVTEGroupedWithGEMMSwizzledScales,
&output_swizzled, sizeof(output_swizzled));

const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape();
const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape();
const size_t row_numel = row_shape.data[0] * row_shape.data[1];
const size_t col_numel = col_shape.data[0] * col_shape.data[1];

// Memset to a non-zero sentinel so we can detect kernel failures to write
// padded regions (those must be overwritten with zero by the kernel).
NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0xCD,
num_tensors * row_numel));
NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0xCD,
num_tensors * col_numel));

nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(),
grouped_output.get_handle(), 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

std::vector<uint8_t> output_row(num_tensors * row_numel);
std::vector<uint8_t> output_col(num_tensors * col_numel);
NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), grouped_output.scale_inv.get(),
output_row.size(), cudaMemcpyDeviceToHost));
NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(),
grouped_output.columnwise_scale_inv.get(),
output_col.size(), cudaMemcpyDeviceToHost));

std::vector<uint8_t> ref_row(num_tensors * row_numel);
std::vector<uint8_t> ref_col(num_tensors * col_numel);
for (int i = 0; i < num_tensors; ++i) {
compute_ref_swizzle<128, 4, true>(
input_tensors[i]->rowwise_cpu_scale_inv_ptr<uint8_t>(),
ref_row.data() + i * row_numel,
row_shape.data[0], row_shape.data[1]);
compute_ref_swizzle<128, 4, false>(
input_tensors[i]->columnwise_cpu_scale_inv_ptr<uint8_t>(),
ref_col.data() + i * col_numel,
col_shape.data[1], col_shape.data[0]);
}

compareResults("grouped_swizzle_compact_rowwise", output_row.data(),
ref_row.data(), num_tensors * row_numel);
compareResults("grouped_swizzle_compact_colwise", output_col.data(),
ref_col.data(), num_tensors * col_numel);
}

class SwizzleGroupedCompactInputTestSuite
: public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};

TEST_P(SwizzleGroupedCompactInputTestSuite, TestGroupedSwizzleMXFP8CompactInput) {
const auto num_tensors = std::get<0>(GetParam());
const auto M = std::get<1>(GetParam());
const auto K = std::get<2>(GetParam());
performTestGroupedSwizzleMXFP8CompactInput(num_tensors, M, K);
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
SwizzleGroupedCompactInputTestSuite,
::testing::Values(
// Aligned M and K. Per-tensor compact stride == per-tensor padded stride,
// so the kernel may use either layout; serves as a sanity check that the
// compact-input plumbing doesn't regress aligned shapes.
std::make_tuple(3, 256, 256),
std::make_tuple(4, 128, 128),
// M NOT divisible by 128 (the original-bug case): the per-tensor compact
// stride shrinks vs. the padded one. We pick (num_tensors, M) so that BOTH
//   round_up(N * M, 128) != N * round_up(M, 128)                    (rowwise)
//   round_up(N * DIVUP(M, 32), 4) != N * round_up(DIVUP(M, 32), 4)  (colwise)
// hold, i.e. compact_total != padded_total on both axes, so the kernel
// unambiguously detects the compact layout (see the worked check after
// this instantiation).
std::make_tuple(4, 200, 256),
std::make_tuple(4, 65, 256),
std::make_tuple(2, 2880, 2880), // shape from the originally failing workload
// K not divisible by 128 (DIVUP(K,32) padded up to a multiple of 4).
std::make_tuple(3, 256, 160),
std::make_tuple(2, 256, 96),
// Neither M nor K aligned.
std::make_tuple(4, 200, 160),
std::make_tuple(4, 33, 64),
std::make_tuple(2, 1, 32),
// num_tensors * M not aligned to 128 -> exercises trailing alignment slack
// at the end of the compact rowwise buffer.
std::make_tuple(3, 64, 128),
std::make_tuple(5, 33, 96)
),
[](const testing::TestParamInfo<SwizzleGroupedCompactInputTestSuite::ParamType>& info) {
return "n" + std::to_string(std::get<0>(info.param)) +
"_M" + std::to_string(std::get<1>(info.param)) +
"_K" + std::to_string(std::get<2>(info.param));
}
);
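
// Worked check for the compact-layout detection conditions above, using the
// (num_tensors=4, M=200) case (DIVUP(200, 32) = 7):
//   rowwise: round_up(4 * 200, 128) = 896  vs  4 * round_up(200, 128) = 1024
//   colwise: round_up(4 * 7, 4)     = 28   vs  4 * round_up(7, 4)     = 32
// Both totals differ, so the compact grouped buffer cannot be mistaken for
// the per-tensor padded one on either axis.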

class UnswizzleGroupedTestSuite
: public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};

transformer_engine/common/common.h (20 additions, 0 deletions)
@@ -946,6 +946,26 @@ struct TypeInfo {
} \
}

#define TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(INTEGER_ELTS_NUM, type, ...) \
switch (INTEGER_ELTS_NUM) { \
case 1: { \
using type = int; \
{ __VA_ARGS__ } \
} break; \
case 2: { \
using type = int2; \
{ __VA_ARGS__ } \
} break; \
case 4: { \
using type = int4; \
{ __VA_ARGS__ } \
} break; \
default: { \
NVTE_ERROR("Unsupported number of integer elements ", INTEGER_ELTS_NUM, \
". Expected one of: 1, 2, or 4."); \
} \
}
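
// A minimal usage sketch for the switch above: dispatch a vectorized copy
// kernel on the number of 4-byte integer elements loaded per thread. The
// kernel and launcher names here are illustrative, not an existing API.
//
//   template <typename VecT>
//   __global__ void copy_vectorized_example(const VecT *src, VecT *dst, size_t n) {
//     const size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
//     if (i < n) dst[i] = src[i];  // one vectorized load + store per thread
//   }
//
//   inline void launch_copy_example(const void *src, void *dst, size_t num_vecs,
//                                   int num_int_elts, cudaStream_t stream) {
//     TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(num_int_elts, VecT,
//       constexpr int kThreads = 256;
//       const int blocks = static_cast<int>((num_vecs + kThreads - 1) / kThreads);
//       copy_vectorized_example<VecT><<<blocks, kThreads, 0, stream>>>(
//           static_cast<const VecT *>(src), static_cast<VecT *>(dst), num_vecs);
//     );
//   }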

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int log2_ceil(int value) {