diff --git a/problems/nvidia/nvfp4_gemm/eval.py b/problems/nvidia/nvfp4_gemm/eval.py
new file mode 100644
index 0000000..e8bb5b2
--- /dev/null
+++ b/problems/nvidia/nvfp4_gemm/eval.py
@@ -0,0 +1,500 @@
import base64
import dataclasses
import multiprocessing
import re
import time
import os
import sys
import math
from pathlib import Path
from typing import Any, Optional
import tempfile

import torch.cuda
from cutlass.cute.nvgpu.common import OpError

from utils import set_seed, clear_l2_cache

try:
    from task import TestSpec
except ImportError:
    # Fallback when the task module does not declare a TestSpec type.
    TestSpec = dict

from reference import check_implementation, generate_input


class PopcornOutput:
    """Line-oriented 'key: value' logger writing to an inherited file descriptor."""

    def __init__(self, fd: int):
        self.file = os.fdopen(fd, "w")
        # Make sure worker subprocesses do not inherit the reporting fd.
        os.set_inheritable(fd, False)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def print(self, *args, **kwargs):
        # Flush on every write so the harness sees results immediately.
        print(*args, **kwargs, file=self.file, flush=True)

    def log(self, key, value):
        self.print(f"{key}: {value}")


@dataclasses.dataclass
class TestCase:
    # args: parsed key/value pairs for generate_input (ints where possible)
    # spec: the original unparsed test-definition line, for logging
    args: dict
    spec: str


def _combine(a: int, b: int) -> int:
    # combine two integers into one:
    # we need this to generate a secret seed based on the test-level seed and
    # the global secret seed.
    # the test-level seeds are public knowledge, and typically relatively small numbers,
    # so we need to make sure they don't provide any useful info for the full seed.
    # This Cantor construction ensures that if the secret seed is a large number,
    # then so is the overall seed.
    return int(a + (a + b) * (a + b + 1) // 2)


def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]:
    """
    Parse a test-definition file into TestCase objects.

    Each line is a ';'-separated list of 'key: value' pairs; integer values
    are converted, everything else stays a string. Exits the process with
    status 113 on an unreadable file or malformed line. When `seed` is given,
    every per-test "seed" argument is combined with it via _combine.
    """
    try:
        content = Path(file_name).read_text()
    except Exception as E:
        print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr)
        exit(113)

    tests = []
    lines = content.splitlines()
    match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*"
    for line in lines:
        parts = line.split(";")
        case = {}
        for part in parts:
            matched = re.match(match, part)
            if not re.fullmatch(match, part):
                print(f"invalid test case: '{line}': '{part}'", file=sys.stderr)
                exit(113)
            key = matched[1]
            val = matched[2]
            try:
                val = int(val)
            except ValueError:
                # non-numeric values (e.g. identifiers) are kept as strings
                pass

            case[key] = val
        tests.append(TestCase(spec=line, args=case))

    if seed is not None:
        for test in tests:
            if "seed" in test.args:
                test.args["seed"] = _combine(test.args["seed"], seed)

    return tests


@dataclasses.dataclass
class Stats:
    # Summary statistics of one benchmark's timing runs (durations in ns).
    runs: int
    mean: float
    std: float
    err: float
    best: float
    worst: float


def calculate_stats(durations: list[int]):
    """
    Calculate statistical data from a list of durations.

    @param durations: A list of durations in nanoseconds.
    @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations.
    """
    runs = len(durations)
    total = sum(durations)
    best = min(durations)
    worst = max(durations)

    avg = total / runs
    # Sample variance (Bessel's correction); requires runs >= 2.
    variance = sum(map(lambda x: (x - avg) ** 2, durations))
    std = math.sqrt(variance / (runs - 1))
    # Standard error of the mean.
    err = std / math.sqrt(runs)

    return Stats(
        runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst)
    )


def _clone_data(data):
    """
    Recursively goes through data and clones all tensors.
    """
    if isinstance(data, tuple):
        return tuple(_clone_data(x) for x in data)
    elif isinstance(data, list):
        return [_clone_data(x) for x in data]
    elif isinstance(data, dict):
        return {k: _clone_data(v) for k, v in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.clone()
    else:
        # non-tensor leaves (ints, strings, ...) are shared, not copied
        return data


def _run_single_test(test: TestCase):
    """
    Runs a single test case. Do not call directly
    """
    from submission import custom_kernel

    data = generate_input(**test.args)
    torch.cuda.synchronize()
    try:
        # Pass a clone so the original `data` stays pristine for checking.
        submission_output = custom_kernel(_clone_data(data))

    except OpError as E:
        print(f"Encountered {E}", file=sys.stderr)
        return False, str(E)
    torch.cuda.synchronize()
    return check_implementation(data, submission_output)


def run_single_test(pool: multiprocessing.Pool, test: TestCase):
    """
    Runs a single test in another process.
    """
    return pool.apply(_run_single_test, (test,))


def run_testing(
    logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]
):
    """
    Executes the actual test case code and checks for correctness.

    @param logger: A PopcornOutput object used for logging test results.
    @param tests: A list of TestCase objects representing the test cases to be executed.
    @return: An integer representing the exit status: 0 if all tests pass, otherwise 112.
    """
    # Step 1: Compile kernel once before running tests
    logger.log("compile", "start")
    compile_success, compile_error = pool.apply(_compile_kernel_once)
    if not compile_success:
        logger.log("compile", "fail")
        logger.log("compile.error", compile_error)
        return 112
    logger.log("compile", "pass")

    # Step 2: Run all tests with compiled kernel
    passed = True
    logger.log("test-count", len(tests))
    for idx, test in enumerate(tests):
        logger.log(f"test.{idx}.spec", test.spec)
        good, message = run_single_test(pool, test)
        if not good:
            logger.log(f"test.{idx}.status", "fail")
            logger.log(f"test.{idx}.error", message)
            passed = False
        else:
            logger.log(f"test.{idx}.status", "pass")
            if message:
                logger.log(f"test.{idx}.message", message)

    if passed:
        logger.log("check", "pass")
        return 0
    else:
        logger.log("check", "fail")
        return 112


def _compile_kernel_once():
    """
    Compile the kernel once before any benchmarking.
    This ensures compilation time is not included in benchmark results.

    @return: (True, None) on success, (False, error message) on failure.
    """
    from submission import compile_kernel

    try:
        # Trigger compilation (will be cached)
        compile_kernel()
        torch.cuda.synchronize()
        return True, None
    except OpError as E:
        return False, f"Compilation failed: {E}"
    except Exception as E:
        # Broad catch is deliberate: any compile failure must be reported,
        # not crash the worker process.
        return False, f"Compilation failed: {E}"


def _run_single_benchmark(
    test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float
) -> Stats | Any:
    """
    Runs one benchmark. Do not call directly.

    @return: a Stats object on success, or an error message (str) on failure.
    """
    from submission import custom_kernel, compile_kernel

    durations = []
    # generate input data once
    data = generate_input(**test.args)
    check_copy = _clone_data(data)

    # Ensure kernel is compiled before any timing (compilation is cached)
    try:
        compile_kernel()
        torch.cuda.synchronize()
    except OpError as E:
        return f"Compilation failed: {E}"
    except Exception as E:
        return f"Compilation failed: {E}"

    # first, one obligatory correctness check
    try:
        output = custom_kernel(_clone_data(data))
    except OpError as E:
        return f"Encountered {E}"
    good, message = check_implementation(check_copy, output)
    if not good:
        return message

    # now, do multiple timing runs without further correctness testing
    # there is an upper bound of `max_repeats` runs, and a lower bound of 3 runs;
    # otherwise, we repeat until the accumulated kernel time exceeds the
    # `max_time_ns` budget, or the relative error of the mean is below 0.1%.

    bm_start_time = time.perf_counter_ns()
    for i in range(max_repeats):
        if recheck:
            # ensure we use a different seed for every benchmark
            if "seed" in test.args:
                test.args["seed"] += 13

            data = generate_input(**test.args)
            check_copy = _clone_data(data)
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        # Evict L2 so every iteration measures a cold-cache launch.
        clear_l2_cache()

        start_event.record()
        output = custom_kernel(data)
        end_event.record()
        torch.cuda.synchronize()
        duration = start_event.elapsed_time(end_event) * 1e6  # Convert ms to ns

        if recheck:
            good, message = check_implementation(check_copy, output)
            if not good:
                return message

        del output
        durations.append(duration)

        if i > 1:
            total_bm_duration = time.perf_counter_ns() - bm_start_time
            stats = calculate_stats(durations)
            # stop if either
            # a) relative error dips below 0.1%
            # b) we exceed the total time limit for benchmarking the kernel
            # c) we exceed 2 minutes of total wallclock time.
            if (
                stats.err / stats.mean < 0.001
                or stats.mean * stats.runs > max_time_ns
                or total_bm_duration > 120e9
            ):
                break

    return calculate_stats(durations)


def run_single_benchmark(
    pool: multiprocessing.Pool,
    test: TestCase,
    recheck: bool,
    max_repeats: int,
    max_time_ns: float,
):
    """
    For a particular test case, check correctness (if applicable) and grab runtime results.

    @param pool: Process on which the benchmark will be launched.
    @param test: TestCase object.
    @param recheck: Flag for whether to explicitly check functional correctness.
    @param max_repeats: Number of trials to repeat.
    @param max_time_ns: Timeout time in nanoseconds.
    @return: A Stats object for this particular benchmark case or an error if the test fails.
    """
    return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns))


def run_benchmarking(
    logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]
):
    """
    Executes benchmarking code for a CUDA Kernel and logs runtimes.

    @param logger: A PopcornOutput object used for logging benchmark results.
    @param pool: Process on which the benchmarks will be launched.
    @param tests: A list of TestCase objects representing the test cases to be benchmarked.
    @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112.
+ """ + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warm up with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 10e7) + + # Step 3: Run benchmarks (compilation time excluded) + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 200, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. 
Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return 
run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warmup with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 1e7) + + # Step 3: Run leaderboard benchmarks (compilation time excluded) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 200, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? 
+ break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py new file mode 100644 index 0000000..6853098 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -0,0 +1,161 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +# Scaling factor vector size +sf_vec_size = 16 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + + +def ref_kernel( + data: input_t, +) -> output_t: + """ + PyTorch reference implementation of NVFP4 block-scaled GEMM. 
    """
    a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, _, _, c_ref = data

    # Get dimensions from MxNxL layout
    _, _, l = c_ref.shape

    # Call torch._scaled_mm to compute the GEMM result, one batch slice at a time
    for l_idx in range(l):
        # Convert the scale factor tensor to blocked format
        scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx])
        scale_b = to_blocked(sfb_ref_cpu[:, :, l_idx])
        # (m, k) @ (n, k).T -> (m, n)
        res = torch._scaled_mm(
            a_ref[:, :, l_idx],
            b_ref[:, :, l_idx].transpose(0, 1),
            scale_a.cuda(),
            scale_b.cuda(),
            bias=None,
            out_dtype=torch.float16,
        )
        c_ref[:, :, l_idx] = res
    return c_ref


def generate_input(
    m: int,
    n: int,
    k: int,
    l: int,
    seed: int,
):
    """
    Generate input tensors for NVFP4 block-scaled GEMM.

    Args:
        m: Number of rows in matrix A
        n: Number of columns in matrix B
        k: Number of columns in A and rows of B
        l: Batch size
        seed: Random seed for reproducibility

    Returns:
        Tuple of (a, b, scale_a, scale_b, scale_a_permuted, scale_b_permuted, c) where:
        a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type
        b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type
        scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type
        scale_b: [n, k, l] - Input scale factors in torch.float8e4m3fn data type
        scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type
        scale_b_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type
        c: [m, n, l] - Output matrix in torch.float16 data type
    """
    torch.manual_seed(seed)

    # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type
    # (k // 2 bytes per row: each byte packs two 4-bit values)
    a_ref = torch.randint(
        -6, 6, (l, m, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    b_ref = torch.randint(
        -6, 6, (l, n, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    a_ref = a_ref.view(torch.float4_e2m1fn_x2)
    b_ref = b_ref.view(torch.float4_e2m1fn_x2)

    # Create float16 output tensor
    c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute(
        1, 2, 0
    )

    # Helper function to prepare the scale factor tensors for both reference
    # kernel and customize kernel. The customized data layout can be found in:
    # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout
    def create_scale_factor_tensors(l, mn, sf_k):
        # Create the reference scale factor tensor (mn, sf_k, l) on CPU.
        ref_shape = (l, mn, sf_k)
        ref_permute_order = (1, 2, 0)
        # Init with uint8 tensor, then convert to float8_e4m3fn
        # NOTE(review): randint(-3, 3) yields values in [-3, 2], so negative
        # scale factors are generated — confirm this is intentional.
        ref_f8_random_int = torch.randint(-3, 3, ref_shape, dtype=torch.int8, device='cuda')
        ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn)
        # permute to match ref_permute_order
        ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order)

        # Scale-factor atom: 128 rows (32x4) by 4 columns per tile.
        atom_m = (32, 4)
        atom_k = 4
        mma_shape = (
            l,  # batch size
            ceil_div(mn, atom_m[0] * atom_m[1]),
            ceil_div(sf_k, atom_k),
            atom_m[0],
            atom_m[1],
            atom_k,
        )

        # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout
        # Which is needed by the CuTe customized kernel
        mma_permute_order = (3, 4, 1, 5, 2, 0)
        # Generate a random int8 tensor, then convert to float8_e4m3fn
        # (these random values are fully overwritten by the scatter below when
        # mn and sf_k are exact multiples of the atom sizes; otherwise the
        # padding region keeps its random fill)
        rand_int_tensor = torch.randint(-3, 3, mma_shape, dtype=torch.int8, device='cuda')
        reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn)
        # Permute according to mma_permute_order
        reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order)

        # GPU-side vectorized reordering (replaces slow CPU nested loops)
        # Create index grids for all dimensions
        i_idx = torch.arange(mn, device='cuda')
        j_idx = torch.arange(sf_k, device='cuda')
        b_idx = torch.arange(l, device='cuda')

        # Create meshgrid for all combinations of (i, j, b)
        i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij')

        # Calculate target indices in vectorized manner:
        # row i decomposes into (mm32, mm4, mm) = (i % 32, (i % 128) // 32, i // 128),
        # column j into (kk4, kk) = (j % 4, j // 4).
        mm = i_grid // (atom_m[0] * atom_m[1])
        mm32 = i_grid % atom_m[0]
        mm4 = (i_grid % 128) // atom_m[0]
        kk = j_grid // atom_k
        kk4 = j_grid % atom_k

        # Perform the reordering with advanced indexing (all on GPU)
        reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid]

        return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor

    sf_k = ceil_div(k, sf_vec_size)
    sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k)
    sfb_ref_cpu, sfb_ref_permuted = create_scale_factor_tensors(l, n, sf_k)

    return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_ref_permuted, sfb_ref_permuted, c_ref)


check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03)
diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py
new file mode 100644
index 0000000..c2f37d9
--- /dev/null
+++ b/problems/nvidia/nvfp4_gemm/submission.py
@@ -0,0 +1,761 @@
# NOTE(review): this private-torch import appears unused in the visible code
# and looks accidental (likely an auto-import) — confirm before removing.
from torch._higher_order_ops.torchbind import call_torchbind_fake
import cuda.bindings.driver as cuda

import torch
from task import input_t, output_t

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass.cute.runtime import make_ptr

# Kernel configuration parameters
# Tile sizes for M, N, K dimensions
mma_tiler_mnk = (128, 128, 256)
# Shape of the K dimension for the MMA instruction
mma_inst_shape_k = 64
# FP4 data type for A and B
ab_dtype = cutlass.Float4E2M1FN
# FP8 data type for scale factors
sf_dtype = cutlass.Float8E4M3FN
# FP16 output type
c_dtype = cutlass.Float16
# Scale factor block size (16 elements share one scale)
sf_vec_size = 16
# Number of threads per CUDA thread block
threads_per_cta = 128
# Stage numbers of shared memory and tmem
+num_acc_stage = 1 +num_ab_stage = 1 +# Total number of columns in tmem +num_tmem_alloc_cols = 512 + + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# The CuTe reference implementation for NVFP4 block-scaled GEMM +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + num_tma_load_bytes: cutlass.Constexpr[int], +): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + tidx = cute.arch.thread_idx() + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Define shared storage for kernel + # + @cute.struct + class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=ab_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + 
swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + # + # Initialize mainloop ab_pipeline, acc_pipeline and their states + # + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=num_tma_load_bytes, + ).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=num_acc_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + threads_per_cta, + ), + ).make_participants() + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/SFA/SFB/C + # + # (MMA, MMA_M, MMA_K, RestK) + 
thr_mma = tiled_mma.get_slice(0) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B/SFA/SFB + # + # TMA Partition_S/D for A + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA Partition_S/D for B + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1), + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + # TMA Partition_S/D for SFA + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1), + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + # TMA Partition_S/D for SFB + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + 0, + cute.make_layout(1), + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, 
MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Alloc tensor memory buffer + # + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=threads_per_cta, + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + ) + tmem.allocate(num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Make SFA/SFB tmem tensor + # + # Get SFA tmem ptr + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc), + dtype=sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + # Get SFB tmem ptr + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # + # Partition for S2T copy of SFA/SFB + # + # Make S2T CopyAtom + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), + sf_dtype, + ) + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFA_compact = cute.filter_zeros(sSFA) + # (MMA, MMA_MN, MMA_K) + tCtSFA_compact = cute.filter_zeros(tCtSFA) + tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) + thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t = 
tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) + + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB_compact = cute.filter_zeros(sSFB) + # (MMA, MMA_MN, MMA_K) + tCtSFB_compact = cute.filter_zeros(tCtSFB) + tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB_compact) + thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB_compact) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # + # Execute Data copy and Math computation in the k_tile loop + # + if warp_idx == 0: + # Wait for accumulator buffer empty + acc_empty = acc_producer.acquire_and_advance() + # Set ACCUMULATE field to False for the first k_tile iteration + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Execute k_tile loop + for k_tile in range(k_tile_cnt): + # Wait for AB buffer empty + ab_empty = ab_producer.acquire_and_advance() + + # TMA load A/B/SFA/SFB to shared memory + cute.copy( + tma_atom_a, + tAgA[(None, k_tile)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b, + tBgB[(None, k_tile)], + tBsB[(None, 
ab_empty.index)],
+                tma_bar_ptr=ab_empty.barrier,
+            )
+            cute.copy(
+                tma_atom_sfa,
+                tAgSFA[(None, k_tile)],
+                tAsSFA[(None, ab_empty.index)],
+                tma_bar_ptr=ab_empty.barrier,
+            )
+            cute.copy(
+                tma_atom_sfb,
+                tBgSFB[(None, k_tile)],
+                tBsSFB[(None, ab_empty.index)],
+                tma_bar_ptr=ab_empty.barrier,
+            )
+
+            # Wait for AB buffer full
+            ab_full = ab_consumer.wait_and_advance()
+
+            # Copy SFA/SFB from shared memory to TMEM
+            s2t_stage_coord = (None, None, None, None, ab_full.index)
+            tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord]
+            tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord]
+            cute.copy(
+                tiled_copy_s2t_sfa,
+                tCsSFA_compact_s2t_staged,
+                tCtSFA_compact_s2t,
+            )
+            cute.copy(
+                tiled_copy_s2t_sfb,
+                tCsSFB_compact_s2t_staged,
+                tCtSFB_compact_s2t,
+            )
+
+            # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB
+            num_kblocks = cute.size(tCrA, mode=[2])
+            for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):  # fully unrolled over kblocks of this stage
+                kblock_coord = (
+                    None,
+                    None,
+                    kblock_idx,
+                    ab_full.index,
+                )
+
+                # Set SFA/SFB tensor to tiled_mma
+                sf_kblock_coord = (None, None, kblock_idx)
+                tiled_mma.set(
+                    tcgen05.Field.SFA,
+                    tCtSFA[sf_kblock_coord].iterator,
+                )
+                tiled_mma.set(
+                    tcgen05.Field.SFB,
+                    tCtSFB[sf_kblock_coord].iterator,
+                )
+
+                cute.gemm(
+                    tiled_mma,
+                    tCtAcc,
+                    tCrA[kblock_coord],
+                    tCrB[kblock_coord],
+                    tCtAcc,
+                )
+                # Enable accumulate on tCtAcc after first kblock
+                tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
+
+            # Async arrive AB buffer empty
+            ab_full.release()
+            acc_empty.commit()
+
+        #
+        # Epilogue: move the TMEM accumulator to registers, then to global C
+        # Partition for epilogue
+        #
+        op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE)
+        copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32)
+        tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc)
+        thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
+        # (T2R_M, T2R_N, EPI_M, EPI_M)
+        tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc)
+        # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
+        tTR_gC = thr_copy_t2r.partition_D(tCgC)
+        # (T2R_M, T2R_N, EPI_M, EPI_N)
+        tTR_rAcc = cute.make_rmem_tensor(
+            tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32
+        )
+        # (T2R_M, T2R_N, EPI_M, EPI_N)
+        tTR_rC = cute.make_rmem_tensor(
+            tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype
+        )
+        # STG Atom
+        simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype)
+        tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)]
+
+        # Wait for accumulator buffer full
+        acc_full = acc_consumer.wait_and_advance()
+
+        # Copy accumulator to register
+        cute.copy(tiled_copy_t2r, tTR_tAcc, tTR_rAcc)
+        acc_vec = tTR_rAcc.load().to(c_dtype)
+        tTR_rC.store(acc_vec)
+        # Store C to global memory
+        cute.copy(simt_atom, tTR_rC, tTR_gC)
+
+        acc_full.release()
+
+        # Deallocate TMEM (all threads must arrive before the free)
+        cute.arch.barrier()
+        tmem.free(acc_tmem_ptr)
+
+        return
+
+
+@cute.jit
+def my_kernel(
+    a_ptr: cute.Pointer,
+    b_ptr: cute.Pointer,
+    sfa_ptr: cute.Pointer,
+    sfb_ptr: cute.Pointer,
+    c_ptr: cute.Pointer,
+    problem_size: tuple,
+):
+    """
+    Host-side JIT function to prepare tensors and launch GPU kernel.
+ """ + m, n, k, l = problem_size + + # Setup attributes that depend on gemm inputs + a_tensor = cute.make_tensor( + a_ptr, + cute.make_layout( + (m, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), + ), + ) + b_tensor = cute.make_tensor( + b_ptr, + cute.make_layout( + (n, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, cute.make_layout((cute.assume(m, 32), n, l), stride=(n, 1, m * n)) + ) + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) + + mma_op = tcgen05.MmaMXF4NVF4Op( + sf_dtype, + (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((1, 1, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute A/B/SFA/SFB/C shared memory layout + a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA for A + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, 
tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + a_tensor, + a_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + # Setup TMA for B + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + b_tensor, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + # Setup TMA for SFA + sfa_smem_layout = cute.slice_( + sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfa_tensor, + sfa_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + # Setup TMA for SFB + sfb_smem_layout = cute.slice_( + sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfb_tensor, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Compute TMA load bytes + a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) + num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Compute grid size + grid = ( + cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]), + cute.ceil_div(c_tensor.shape[1], mma_tiler_mnk[1]), + c_tensor.shape[2], + ) + + # Launch the kernel + kernel( + # MMA (Matrix Multiply-Accumulate) configuration + tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern + + # TMA (Tensor Memory Accelerator) atoms and tensors for input matrix A 
+ tma_atom_a, # TMA copy atom defining how to load A from global memory + tma_tensor_a, # Tensor descriptor for A matrix (m, k, l) + + # TMA atoms and tensors for input matrix B + tma_atom_b, # TMA copy atom defining how to load B from global memory + tma_tensor_b, # Tensor descriptor for B matrix (n, k, l) + + # TMA atoms and tensors for scale factor A + tma_atom_sfa, # TMA copy atom for loading scale factors for A + tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) + + # TMA atoms and tensors for scale factor B + tma_atom_sfb, # TMA copy atom for loading scale factors for B + tma_tensor_sfb, # Tensor descriptor for SFB (block scale factors for B) + + # Output tensor C + c_tensor, # Output tensor C where result will be stored (m, n, l) + + # Shared memory layouts with staging for pipelined execution + a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) + b_smem_layout_staged, # Staged shared memory layout for B (includes stage dimension) + sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) + sfb_smem_layout_staged, # Staged shared memory layout for SFB (includes stage dimension) + + # Pipeline synchronization parameter + num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) + ).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +# Global cache for compiled kernel +_compiled_kernel_cache = None +# This function is used to compile the kernel once and cache it and then allow users to +# run the kernel multiple times to get more accurate timing results. +def compile_kernel(): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. 
+
+    Returns:
+        The compiled kernel function
+    """
+    global _compiled_kernel_cache
+
+    if _compiled_kernel_cache is not None:
+        return _compiled_kernel_cache
+
+
+    # Create CuTe pointers for A/B/C/SFA/SFB; null (0) addresses are placeholders
+    a_ptr = make_ptr(
+        ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16
+    )
+    b_ptr = make_ptr(
+        ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16
+    )
+    c_ptr = make_ptr(
+        c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16
+    )
+    sfa_ptr = make_ptr(
+        sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32
+    )
+    sfb_ptr = make_ptr(
+        sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32
+    )
+
+    # Compile the kernel (problem size (0, 0, 0, 0) is a placeholder at compile time)
+    _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (0, 0, 0, 0))
+
+    return _compiled_kernel_cache
+
+
+def custom_kernel(data: input_t) -> output_t:
+    """
+    Execute the block-scaled GEMM kernel.
+
+    This is the main entry point called by the evaluation framework.
+    It converts PyTorch tensors to CuTe tensors, launches the kernel,
+    and returns the result.
+
+    Args:
+        data: Tuple of (a, b, sfa_ref, sfb_ref, sfa_permuted, sfb_permuted, c) PyTorch tensors
+            a: [m, k, l] - Input matrix in float4e2m1fn
+            b: [n, k, l] - Input vector in float4e2m1fn
+            sfa_ref: [m, k, l] - Scale factors in float8_e4m3fn, used by reference implementation
+            sfb_ref: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation
+            sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn
+            sfb_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn
+            c: [m, n, l] - Output vector in float16
+
+    Returns:
+        Output tensor c with computed results
+    """
+    a, b, _, _, sfa_permuted, sfb_permuted, c = data
+
+    # Ensure kernel is compiled (will use cached version if available)
+    # To avoid the compilation overhead, we compile the kernel once and cache it.
+    compiled_func = compile_kernel()
+
+    # Get dimensions from MxKxL layout
+    m, k, l = a.shape
+    n, _, _ = b.shape
+    # Torch use e2m1_x2 data type (two FP4 values per byte), thus k is halved
+    k = k * 2
+
+    # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer
+    a_ptr = make_ptr(
+        ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16
+    )
+    b_ptr = make_ptr(
+        ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16
+    )
+    c_ptr = make_ptr(
+        c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16
+    )
+    sfa_ptr = make_ptr(
+        sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
+    )
+    sfb_ptr = make_ptr(
+        sf_dtype, sfb_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
+    )
+
+    # Execute the compiled kernel
+    compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l))
+
+    return c
\ No newline at end of file
diff --git a/problems/nvidia/nvfp4_gemm/task.py b/problems/nvidia/nvfp4_gemm/task.py
new file mode 100644
index 0000000..66db735
--- /dev/null
+++ b/problems/nvidia/nvfp4_gemm/task.py
@@ -0,0 +1,11 @@
+import torch
+from typing import TypedDict, TypeVar
+
+input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
+output_t = TypeVar("output_t", bound=torch.Tensor)
+class TestSpec(TypedDict):
+    m: int
+    n: int
+    k: int
+    l: int
+    seed: int
\ No newline at end of file
diff --git a/problems/nvidia/nvfp4_gemm/template.py b/problems/nvidia/nvfp4_gemm/template.py
new file mode 100644
index 0000000..3855d69
--- /dev/null
+++ b/problems/nvidia/nvfp4_gemm/template.py
@@ -0,0 +1,25 @@
+from task import input_t, output_t
+
+
+def custom_kernel(data: input_t) -> output_t:
+    """
+    Reference implementation of block-scale fp4 gemm
+    Args:
+        data: Tuple that expands to:
+            a: torch.Tensor[float4e2m1fn] of shape [m, k, l],
+            b: torch.Tensor[float4e2m1fn] of shape [n, k, l],
+            sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l],
+            sfb:
torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l],
+            sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l],
+            sfb_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l],
+            c: torch.Tensor[float16] of shape [m, n, l]
+    Returns:
+        Tensor containing output in float16
+        c: torch.Tensor[float16] of shape [m, n, l]
+    """
+    # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead.
+    a, b, sfa, sfb, sfa_permuted, sfb_permuted, c = data
+
+    # Your implementation here
+
+    return c
\ No newline at end of file
diff --git a/problems/nvidia/nvfp4_gemm/utils.py b/problems/nvidia/nvfp4_gemm/utils.py
new file mode 100644
index 0000000..d9b3a69
--- /dev/null
+++ b/problems/nvidia/nvfp4_gemm/utils.py
@@ -0,0 +1,172 @@
+import os
+import random
+import numpy as np
+import torch
+
+
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+
+def get_device(use_cuda: bool = True) -> torch.device:
+    """Get the appropriate device (GPU or CPU)."""
+    if use_cuda:
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            return torch.device("mps")
+        else:
+            print("No compatible GPU found. Falling back to CPU.")
+    return torch.device("cpu")
+
+
+# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py
+@torch.no_grad()
+def verbose_allclose(
+    received: torch.Tensor,
+    expected: torch.Tensor,
+    rtol=1e-05,
+    atol=1e-08,
+    max_print=5
+) -> list[str]:
+    """
+    Check that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
+    Parameters:
+        received (torch.Tensor): Tensor we actually got.
+        expected (torch.Tensor): Tensor we expected to receive.
+        rtol (float): Relative tolerance; relative to expected
+        atol (float): Absolute tolerance.
+        max_print (int): Maximum number of mismatched elements to print.
+    Returns:
+        A list of mismatch description strings; empty list if the tensors match within tolerance.
+    """
+    # Check if the shapes of the tensors match
+    if received.shape != expected.shape:
+        return ["SIZE MISMATCH"]
+
+    # Calculate the difference between the tensors
+    diff = torch.abs(received - expected)
+
+    # Determine the tolerance
+    tolerance = atol + rtol * torch.abs(expected)
+
+    # Find tolerance mismatched elements
+    tol_mismatched = diff > tolerance
+
+    # Find nan mismatched elements
+    nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected))
+
+    # Find +inf mismatched elements
+    posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected))
+    # Find -inf mismatched elements
+    neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected))
+
+    # Find all mismatched elements
+    mismatched = torch.logical_or(
+        torch.logical_or(tol_mismatched, nan_mismatched),
+        torch.logical_or(posinf_mismatched, neginf_mismatched),
+    )
+
+    mismatched_indices = torch.nonzero(mismatched)
+
+    # Count the number of mismatched elements
+    num_mismatched = mismatched.count_nonzero().item()
+
+    # Generate detailed information if there are mismatches
+    if num_mismatched >= 1:
+        mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
+
+        for index in mismatched_indices[:max_print]:
+            i = tuple(index.tolist())
+            mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}")
+        if num_mismatched > max_print:
+            mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
+        return mismatch_details
+
+    return []
+
+
+@torch.no_grad()
+def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5):
+    """
+    Check that two tensors are element-wise perfectly equal, providing detailed information about mismatches.
+    Parameters:
+        received (torch.Tensor): Tensor we actually got.
+        expected (torch.Tensor): Tensor we expected to receive.
+        max_print (int): Maximum number of mismatched elements to print.
+    Returns:
+        Empty list if tensors are equal, otherwise a list of mismatch detail strings
+    """
+    mismatched = torch.not_equal(received, expected)
+    mismatched_indices = torch.nonzero(mismatched)
+
+    # Count the number of mismatched elements
+    num_mismatched = mismatched.count_nonzero().item()
+
+    # Generate detailed information if there are mismatches
+    if num_mismatched >= 1:
+        mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
+
+        for index in mismatched_indices[:max_print]:
+            i = tuple(index.tolist())
+            mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}")
+        if num_mismatched > max_print:
+            mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
+        return mismatch_details
+
+    return []
+
+
+def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]:
+    """
+    Convenient "default" implementation for tasks' `check_implementation` function.
+    """
+    expected = reference(data)
+    reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol)
+
+    if len(reasons) > 0:
+        return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons)
+
+    return True, ''
+
+
+def make_match_reference(reference: callable, **kwargs):
+    def wrapped(data, output):
+        return match_reference(data, output, reference=reference, **kwargs)
+    return wrapped
+
+
+class DeterministicContext:
+    def __init__(self):
+        self.allow_tf32 = None
+        self.deterministic = None
+        self.cublas = None
+
+    def __enter__(self):
+        self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '')  # NOTE(review): defaults to '' when unset; see __exit__
+        self.allow_tf32 = torch.backends.cudnn.allow_tf32
+        self.deterministic = torch.backends.cudnn.deterministic
+        torch.backends.cudnn.allow_tf32 = False
+        torch.backends.cudnn.deterministic = True
+        torch.use_deterministic_algorithms(True)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        torch.backends.cudnn.allow_tf32 = self.allow_tf32
+        torch.backends.cudnn.deterministic = self.deterministic
+        torch.use_deterministic_algorithms(False)  # NOTE(review): assumes it was False on entry — prior state is not captured
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas  # NOTE(review): writes '' instead of deleting when var was unset
+
+def clear_l2_cache():
+    # import cupy as cp
+    # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0)
+    # create a large dummy tensor
+    dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda")
+    # write to it so the allocation actually touches memory and evicts L2 contents
+    dummy.fill_(42)
+    del dummy