diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 615f173b9..04d849115 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: args: [--ignore-case] files: ^docs/spelling_wordlist\.txt$ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v21.1.2 # sync with requirements-lint.txt + rev: v21.1.6 # sync with requirements-lint.txt hooks: - id: clang-format exclude: | @@ -41,7 +41,7 @@ repos: ^.+\.json$ ) - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.3 # sync with requirements-lint.txt + rev: v0.14.7 # sync with requirements-lint.txt hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] diff --git a/3rdparty/tvm b/3rdparty/tvm index fc7ed0b9c..e8b02611f 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit fc7ed0b9cb7a52eb1c8bf6e8c26bbb8dff3655ce +Subproject commit e8b02611fd6b803273c5c3e15aa3a030c32dbd30 diff --git a/README.md b/README.md index d7cdabee5..0c0769e7d 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) print("Kernel output matches PyTorch reference.") # 4. Retrieve and inspect the generated CUDA source (optional) -# cuda_source = jit_kernel.get_kernel_source() +# cuda_source = matmul_relu_kernel.get_kernel_source() # print("Generated CUDA kernel:\n", cuda_source) # 5.Profile latency with kernel diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py index 4e4ed6128..0ff3cd0b6 100644 --- a/benchmark/matmul/benchmark_matmul_sp.py +++ b/benchmark/matmul/benchmark_matmul_sp.py @@ -9,7 +9,7 @@ from tilelang.autotuner import autotune from tilelang import jit from tilelang.contrib import nvcc -from tilelang.layout import make_metadata_layout +from tilelang.layout import make_cutlass_metadata_layout # Configure logger logger = logging.getLogger(__name__) @@ -86,7 +86,7 @@ def get_configs(M, N, K): return configs -def matmul_sp(M, N, K, accum_dtype): +def matmul_sp(M, N, K, in_dtype, accum_dtype): """ Create an autotuned matrix multiplication kernel for matrices of shape: - A: (M, K) @@ -161,14 +161,13 @@ def kernel( """ # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float16" e_factor, e_dtype = ARCH_INFO[arch] @T.prim_func def main( - A_sparse: T.Tensor((M, K // 2), dtype), + A_sparse: T.Tensor((M, K // 2), in_dtype), E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), dtype), + B: T.Tensor((K, N), in_dtype), C: T.Tensor((M, N), accum_dtype), ): """ @@ -187,9 +186,9 @@ def main( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) - A_shared = T.alloc_shared((block_M, block_K // 2), dtype) + A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) - B_shared = T.alloc_shared((block_K, block_N), dtype) + B_shared = T.alloc_shared((block_K, block_N), in_dtype) # Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) # Allocate a local fragment for intermediate accumulation @@ -204,11 +203,9 @@ def main( T.use_swizzle(panel_size=10, enable=enable_rasterization) T.annotate_layout({ E: - make_metadata_layout( - E, mma_dtype="float16", backend="cutlass", block_k=block_K), + make_cutlass_metadata_layout(E, mma_dtype=in_dtype, 
block_k=block_K), E_shared: - make_metadata_layout( - E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K), + make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), }) # Loop over sub-blocks in K dimension, pipelined by num_stages for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -220,7 +217,7 @@ def main( T.copy(B[k * block_K, bx * block_N], B_shared) # Perform a partial matrix multiplication: # C_local += A_shared @ B_shared - T.gemm_sp( + T.gemm_sp_v2( A_shared, E_shared, B_shared, @@ -268,7 +265,7 @@ def main( total_flops = 2 * M * N * K # matmul(...) returns (best_latency, best_config, ref_latency) - best_result = matmul_sp(M, N, K, args.accum_dtype) + best_result = matmul_sp(M, N, K, "float16", args.accum_dtype) best_latency = best_result.latency best_config = best_result.config A = torch.randn(M, K, dtype=torch.float16, device="cuda") diff --git a/docs/_static/img/sparse_mma_storage_example.png b/docs/_static/img/sparse_mma_storage_example.png new file mode 100644 index 000000000..0b1639819 Binary files /dev/null and b/docs/_static/img/sparse_mma_storage_example.png differ diff --git a/docs/deeplearning_operators/matmul_sparse.md b/docs/deeplearning_operators/matmul_sparse.md new file mode 100644 index 000000000..5910bd6f8 --- /dev/null +++ b/docs/deeplearning_operators/matmul_sparse.md @@ -0,0 +1,262 @@ +# Sparse Matrix-Matrix Multiplication with Tile Library + +
+ Author: botbw +
+ +:::{warning} + This document is still **experimental** and may be incomplete. + + This feature is still **experimental** and need further optimization. + + Suggestions and improvements are highly encouraged—please submit a PR! +::: + +:::{tip} +It's suggested to go through `docs/deeplearning_operators/matmul.md` first. + +Example code can be found at `examples/gemm_sp`. +::: + +## Structured sparsity in the NVIDIA Ampere architecture + +Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation. + +:::{warning} + This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X. +::: + +```{figure} ../_static/img/sparse_mma_storage_example.png +:align: center + +Figure: Sparse MMA storage example (from PTX doc) +``` + +## Compress a dense tensor + +To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata. + +Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`). + +A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression. + +```python +from tilelang.utils.sparse import compress +A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) +``` + +Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern. + +> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor) +The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads). +For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**. + + +## `T.gemm_sp` with CUTLASS's compressor + +:::{warning} + +It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time. + +::: + +A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata. + +Check comments in below kernel code for required modification. 
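For context, here is a minimal host-side sketch of preparing the operands that such a kernel consumes. The sizes are illustrative; `compress` and `randn_semi_sparse` are the `tilelang.utils.sparse` helpers introduced above, and the metadata shape/dtype noted in the comments assumes the Ampere fp16 path.

```python
import torch
from tilelang.utils.sparse import compress, randn_semi_sparse

M, K, block_K = 256, 256, 64  # illustrative sizes
A = randn_semi_sparse(M, K, device="cuda", dtype=torch.half)  # 2:4 pattern along K
A_sparse, E = compress(A, transposed=False, block_k=block_K)

# A_sparse: (M, K // 2) fp16 tensor holding the kept non-zero values
# E:        (M, K // 16) int16 metadata on sm8x for fp16 inputs
print(A_sparse.shape, E.shape)
```

`A_sparse`, `E`, and the dense `B` are then passed to the sm80 kernel below.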
+ +```python +def matmul_sp_sm80( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + trans_A, + trans_B, +): + is_8_bit = "8" in in_dtype + metadata_dtype = 'int32' if is_8_bit else 'int16' + E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ # Annotate reordered cutlass metadata layout + E: + make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: + make_cutlass_metadata_layout( + E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main +``` + +Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`. + +## `T.gemm_sp_v2` with a custom compressor + +To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`. + +Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors. + +The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs. 
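As a quick shape check before the worked example (a sketch assuming the Ampere fp16 path with `int16` metadata, as used by the kernels above):

```python
# Each group of 4 input elements contributes one 4-bit metadata chunk
# (two 2-bit indices), so a single int16 element covers 16 // 4 = 4 groups,
# i.e. 4 * 4 = 16 elements along K -> e_factor = 16 for fp16 with int16 metadata.
bits_per_group = 4
groups_per_int16 = 16 // bits_per_group   # 4
e_factor = 4 * groups_per_int16           # 16
M, K = 128, 256                           # illustrative sizes
A_sparse_shape = (M, K // 2)              # kept non-zero values
E_shape = (M, K // e_factor)              # (128, 16) int16 metadata
```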
+ +Suppose we have the following row vector: +```python +t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten() +``` + +The non-zero elements and their corresponding indices are: + +```python +t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten() +indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten() +``` + +The corresponding uint16 metadata is: +```python +# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000]) +# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16) +# Note: the above code is not runnable in python as the interpreter won't take the binary +# as 2's complement +metadata_int16 = tensor(-29107) +``` + +You can decode an int16 metadata tensor using the following utility: +```python +def decode_metadata(meta: torch.Tensor) -> torch.Tensor: + assert meta.dtype is torch.int16 + groups_per_meta = 16 // 4 + out = [] + for g in range(groups_per_meta): + group_bits = (meta >> (g * 4)) & 0xF + idx0 = group_bits & 0x3 + idx1 = (group_bits >> 2) & 0x3 + out.append(torch.stack([idx0, idx1], dim=-1)) + return torch.concat(out, dim=-1).view(meta.shape[0], -1) +``` + +The compressor can be implement at either `PyTorch`/`NumPy` level or kernel level. + +For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function. + +If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel. 
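Whichever layout you choose, a host-side reference for the naive packing is useful when debugging a compressor. The sketch below is the inverse of `decode_metadata`; `encode_metadata` is a hypothetical helper (not part of the library) and assumes the naive row-major layout with no CUTLASS reordering.

```python
import torch


def encode_metadata(indices: torch.Tensor) -> torch.Tensor:
    """Pack per-group index pairs into naive (non-reordered) int16 metadata.

    `indices` has shape (rows, num_groups, 2): for each group of 4 input
    elements, the positions of the two kept values. Inverse of
    `decode_metadata` above.
    """
    rows, num_groups, _ = indices.shape
    assert num_groups % 4 == 0, "4 groups are packed into each int16 element"
    idx = indices.to(torch.int32)
    nibbles = idx[..., 0] | (idx[..., 1] << 2)   # 4 bits per group
    nibbles = nibbles.view(rows, num_groups // 4, 4)
    meta = (nibbles[..., 0]
            | (nibbles[..., 1] << 4)
            | (nibbles[..., 2] << 8)
            | (nibbles[..., 3] << 12))
    # Values >= 2**15 wrap to negative int16, matching the example above.
    return meta.to(torch.int16)


# Round-trips the toy example: [[1, 3], [0, 1], [2, 3], [0, 2]] -> -29107
encode_metadata(torch.tensor([[[1, 3], [0, 1], [2, 3], [0, 2]]]))
```

The TileLang compressor kernel below produces either this naive layout or the CUTLASS layout, depending on its `use_cutlass_layout` flag.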
+ +```python + +@tilelang.jit(out_idx=[1, 2], pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, +}) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: # NOTE: Make sure compressor metadata layout + T.annotate_layout({ # is same with your computation kernel + E: + make_cutlass_metadata_layout( + E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: + make_cutlass_metadata_layout( + E_shared, + mma_dtype="float16", + arch="8.0", + block_k=block_K), + }) + T.clear(A_sp_shared) + T.clear(E_shared) + non_zero_cnt = T.alloc_local((1, ), dtype="uint8") + non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8") + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + T.clear(non_zero_cnt) + T.clear(non_zero_elt_log_idx) + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel +``` + +## A note on `gemm_sp` and `gemm_sp_v2` + +Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout. + +However, fixing a specific layout introduces several potential issues: + +1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling. + +2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically. + +3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.) + +`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout. 
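As a practical aid for the debugging pain described above, a compressor that targets the naive layout can be checked end to end by scattering the kept values back to their decoded positions. This is a sketch only: it reuses `decode_metadata` from above and `randn_semi_sparse` from `tilelang.utils.sparse`, while `torch_compress` stands in for any naive-layout compressor (one such helper ships in `examples/gemm_sp/example_custom_compress.py`).

```python
import torch
from tilelang.utils.sparse import randn_semi_sparse


def reconstruct_dense(a_sparse: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
    """Rebuild a strictly 2:4-sparse tensor from naive-layout metadata."""
    m, half_k = a_sparse.shape
    k = half_k * 2
    idx = decode_metadata(meta).view(m, k // 4, 2)   # per-group kept positions
    base = torch.arange(0, k, 4, device=a_sparse.device).view(1, -1, 1)
    cols = (base + idx).reshape(m, -1)               # absolute column indices
    dense = torch.zeros(m, k, dtype=a_sparse.dtype, device=a_sparse.device)
    dense.scatter_(1, cols.long(), a_sparse)
    return dense


A = randn_semi_sparse(128, 128, device="cuda", dtype=torch.half)
A_sparse, E = torch_compress(A)   # any compressor emitting the naive layout
torch.testing.assert_close(reconstruct_dense(A_sparse, E), A)
```

Note that this check only applies to the naive layout; metadata produced with the CUTLASS layout is reordered and will not decode this way.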
\ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 9f7947766..45e7f5ea4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -33,6 +33,7 @@ tutorials/auto_tuning deeplearning_operators/elementwise deeplearning_operators/gemv deeplearning_operators/matmul +deeplearning_operators/matmul_sparse deeplearning_operators/deepseek_mla ::: diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index ed44aab69..0e2c437e3 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -51,7 +51,12 @@ def tl_matmul( micro_size_x = micro_size_y = micro_size_k = 16 - is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] + is_float8 = in_dtype in [ + "float8_e4m3", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fnuz", + ] if out_dtype == "int32" or is_float8: micro_size_k = 32 diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py new file mode 100644 index 000000000..5125aed07 --- /dev/null +++ b/examples/gemm_sp/example_custom_compress.py @@ -0,0 +1,364 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. +import argparse + +import tilelang +import tilelang.language as T + +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import randn_semi_sparse +from tilelang.utils.tensor import torch_assert_close + +from triton.testing import do_bench + +import torch + +torch.manual_seed(42) + +DEFAULT_CONFIG = { # take best config from autotune script + "4090": { + 'float': { + 'block_M': 128, + 'block_N': 64, + 'block_K': 64, + 'num_stages': 1, + 'thread_num': 128, + 'policy': T.GemmWarpPolicy.Square, + 'enable_rasterization': True + }, + 'float16': { + 'block_M': 256, + 'block_N': 128, + 'block_K': 64, + 'num_stages': 2, + 'thread_num': 128, + 'policy': T.GemmWarpPolicy.Square, + 'enable_rasterization': True + } + }, + "h20": { + 'float': { + 'block_M': 128, + 'block_N': 64, + 'block_K': 128, + 'num_stages': 3, + 'thread_num': 128, + 'policy': T.GemmWarpPolicy.Square, + 'enable_rasterization': True + }, + 'float16': { + 'block_M': 128, + 'block_N': 64, + 'block_K': 128, + 'num_stages': 3, + 'thread_num': 128, + 'policy': T.GemmWarpPolicy.Square, + 'enable_rasterization': True + } + } +} + +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, + thread_num, policy, enable_rasterization, use_cutlass_layout): + e_factor, e_dtype = (16, "int16") + + @T.prim_func + def gemm_sp_fp16_custom_compress( + A_sparse: T.Tensor((M, K // 2), 'float16'), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K // 2), 'float16') + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + B_shared = T.alloc_shared((block_K, block_N), 'float16') + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + if use_cutlass_layout: + T.annotate_layout({ + E: + make_cutlass_metadata_layout( + E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: + make_cutlass_metadata_layout( + E_shared, mma_dtype="float16", arch="8.0", 
block_k=block_K), + }) + T.clear(C_local) + T.disable_warp_group_reg_alloc() + T.use_swizzle(panel_size=10, enable=enable_rasterization) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_sp_fp16_custom_compress + + +def torch_compress(dense): + """ + A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout. + """ + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") + + m, k = dense.shape + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. 
Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + return (sparse, meta) + + +def decode_metadata(meta: torch.Tensor) -> torch.Tensor: + assert meta.dtype is torch.int16 + groups_per_meta = 16 // 4 # 4 groups per uint16 + out = [] + for g in range(groups_per_meta): + group_bits = (meta >> (g * 4)) & 0xF + idx0 = group_bits & 0x3 + idx1 = (group_bits >> 2) & 0x3 + out.append(torch.stack([idx0, idx1], dim=-1)) + return torch.concat(out, dim=-1).view(meta.shape[0], -1) + + +@tilelang.jit( + out_idx=[1, 2], pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, + }) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: + T.annotate_layout({ + E: + make_cutlass_metadata_layout( + E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: + make_cutlass_metadata_layout( + E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), + }) + T.clear(A_sp_shared) + T.clear(E_shared) + # TODO: alloc_var seems buggy here + non_zero_cnt = T.alloc_local((1,), dtype="uint8") + non_zero_elt_log_idx = T.alloc_local((elem,), dtype="uint8") + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + non_zero_cnt[0] = 0 + for i in range(elem): + non_zero_elt_log_idx[i] = 0 + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + # TODO: use 
T.device_assert(non_zero_cnt <= 2) after rebasing main + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left( + val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel + + +def main(): + + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument( + "--use_cutlass_layout", action='store_true', help="Use cutlass layout for E tensor") + parser.add_argument( + "--use_torch_compressor", action='store_true', help="Use torch sparse for reference") + parser.add_argument( + "--accum_dtype", + type=str, + default="float", + choices=["float", "float16"], + help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") + args = parser.parse_args() + kernel = matmul_sp_fp16_custom_compress( + args.m, + args.n, + args.k, + args.accum_dtype, + **DEFAULT_CONFIG[args.cfg][args.accum_dtype], + use_cutlass_layout=args.use_cutlass_layout) + + a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) + b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) + + if args.use_torch_compressor: + assert not args.use_cutlass_layout, "torch sparse must be used with naive layout" + a_sparse, e = torch_compress(a) + else: + a_sparse, e = compress_kernel( + args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)( + a) + + c = kernel(a_sparse, e, b) + + ref_c = a @ b + + assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" + torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3) + print( + f"Precision check passed. 
Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}" + ) + + latency = do_bench(lambda: kernel(a_sparse, e, b)) + ref_latency = do_bench(lambda: a @ b) + + total_flops = 2 * args.m * args.n * args.k + tflops = total_flops / latency / 1e9 + ref_tflops = total_flops / ref_latency / 1e9 + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 505f2b883..91682a9e4 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -5,7 +5,7 @@ import tilelang import tilelang.language as T -from tilelang.layout import make_metadata_layout +from tilelang.layout import make_cutlass_metadata_layout from tilelang.utils.sparse import compress, randn_semi_sparse from tilelang.contrib import nvcc from triton.testing import do_bench @@ -14,9 +14,7 @@ arch = nvcc.get_target_compute_version() -ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} - -default_config = { # take best config from autotune script +DEFAULT_CONFIG = { # take best config from autotune script "4090": { 'float': { 'block_M': 128, @@ -59,6 +57,8 @@ } } +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + @tilelang.jit(out_idx=[-1]) def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, @@ -84,15 +84,11 @@ def gemm_sp_fp16( T.use_swizzle(panel_size=10, enable=enable_rasterization) T.annotate_layout({ E: - make_metadata_layout( - E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch), + make_cutlass_metadata_layout( + E, mma_dtype="float16", block_k=block_K, arch=arch), E_shared: - make_metadata_layout( - E_shared, - mma_dtype="float16", - backend="cutlass", - block_k=block_K, - arch=arch), + make_cutlass_metadata_layout( + E_shared, mma_dtype="float16", block_k=block_K, arch=arch), }) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) @@ -117,10 +113,10 @@ def main(): default="float", choices=["float", "float16"], help="Accumulation datatype") - parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True) + parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") args = parser.parse_args() kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, - **default_config[args.cfg][args.accum_dtype]) + **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) @@ -128,7 +124,7 @@ def main(): a_sparse, e = compress( a, transposed=False, - block_k=default_config[args.cfg][args.accum_dtype]['block_K'], + block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'], arch=arch) c = kernel(a_sparse, e, b) diff --git a/examples/gemm_sp/test_example_gemm_sp.py b/examples/gemm_sp/test_example_gemm_sp.py new file mode 100644 index 000000000..fe26df144 --- /dev/null +++ b/examples/gemm_sp/test_example_gemm_sp.py @@ -0,0 +1,16 @@ +import tilelang.testing + +import example_custom_compress +import example_gemm_sp + + +def test_example_custom_compress(): + example_custom_compress.main() + + +def test_example_gemm_sp(): + example_gemm_sp.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git 
a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 58e0114be..3772dc6bd 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -360,7 +360,7 @@ def main(do_bench: bool = True): print("Test passed!") - if not do_bench: + if do_bench: best_result = get_autotuned_kernel(N, K) best_config = best_result.config kernel = splitk_gemv_vectorized_tvm(N, K, **best_config) diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py index 59c79c283..8707c9430 100644 --- a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py +++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -1,7 +1,7 @@ import torch import tilelang from tilelang.utils.sparse import compress_sm90 -from tilelang.layout import make_metadata_layout +from tilelang.layout import make_cutlass_metadata_layout import tilelang.testing @@ -40,15 +40,11 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.annotate_layout({ E: - make_metadata_layout( - E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K), + make_cutlass_metadata_layout( + E, mma_dtype="float16", arch="9.0", block_k=block_K), E_shared: - make_metadata_layout( - E_shared, - mma_dtype="float16", - arch="9.0", - backend="cutlass", - block_k=block_K), + make_cutlass_metadata_layout( + E_shared, mma_dtype="float16", arch="9.0", block_k=block_K), }) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 1a49b7706..4ae19bafc 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -539,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return vectorized_thread_loop; } -TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) +TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd) .set_num_inputs(2) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/copy.cc b/src/op/copy.cc index 1bd548bc5..7bef87d64 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -57,7 +57,7 @@ static int to_CUtensorMapDataType(DataType dtype) { } } else if (dtype.is_bfloat16()) { tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - } else if (dtype.is_float8_e4m3() || dtype.is_float8_e5m2()) { + } else if (dtype.is_float8()) { tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if (dtype.is_int()) { switch (dtype.bits()) { @@ -2037,7 +2037,7 @@ Array TMAIm2ColDesc::EncodeCallArgs() const { // - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, // eviction_policy // - Marked as opaque since it has side effects (memory writes) -TIR_REGISTER_TL_OP(Copy, copy) +TIR_REGISTER_TL_TILE_OP(Copy, copy) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -2062,7 +2062,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T, // - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride, // dilation, padding, eviction_policy // - Marked as opaque since it has side effects (memory writes) -TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col) +TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col) .set_num_inputs(9) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/fill.cc b/src/op/fill.cc index 5a773768a..714e97ad2 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -209,7 +209,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T, return {}; } -TIR_REGISTER_TL_OP(Fill, fill) 
+TIR_REGISTER_TL_TILE_OP(Fill, fill) .set_num_inputs(2) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc index effc4baf0..f542b2d91 100644 --- a/src/op/finalize_reducer.cc +++ b/src/op/finalize_reducer.cc @@ -159,7 +159,7 @@ TileOperator FinalizeReducerOpNode::Clone() const { return TileOperator(node); } -TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer) +TIR_REGISTER_TL_TILE_OP(FinalizeReducerOp, finalize_reducer) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 5a98cba69..57c02b0b5 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -361,13 +361,7 @@ bool GemmNode::checkWgmma() const { if (c_->dtype == DataType::Float(16)) { if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) return k_ % 16 == 0; - else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2()) + else if (a_->dtype.is_float8() && b_->dtype.is_float8()) return (!transA_) && transB_ && k_ % 32 == 0; else return false; @@ -380,13 +374,7 @@ bool GemmNode::checkWgmma() const { else if (a_->dtype == DataType::Float(32) && b_->dtype == DataType::Float(32)) return (!transA_) && transB_ && k_ % 8 == 0; - else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2()) + else if (a_->dtype.is_float8() && b_->dtype.is_float8()) return (!transA_) && transB_ && k_ % 32 == 0; else return false; @@ -826,7 +814,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, return results; } -TIR_REGISTER_TL_OP(Gemm, gemm) +TIR_REGISTER_TL_TILE_OP(Gemm, gemm) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index aa6c02823..378fcc6a5 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -182,13 +182,7 @@ bool GemmPyNode::checkWgmma() const { if (c_->dtype == DataType::Float(16)) { if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) return k_ % 16 == 0; - else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2()) + else if (a_->dtype.is_float8() && b_->dtype.is_float8()) return (!transA_) && transB_ && k_ % 32 == 0; else return false; @@ -201,13 +195,7 @@ bool GemmPyNode::checkWgmma() const { else if (a_->dtype == DataType::Float(32) && b_->dtype == DataType::Float(32)) return (!transA_) && transB_ && k_ % 8 == 0; - else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) - return 
(!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3()) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2()) + else if (a_->dtype.is_float8() && b_->dtype.is_float8()) return (!transA_) && transB_ && k_ % 32 == 0; else return false; @@ -318,7 +306,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, return results; } -TIR_REGISTER_TL_OP(GemmPy, gemm_py) +TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index df923d0e9..4c0ae08b9 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -302,12 +302,25 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, return results; } -TIR_REGISTER_TL_OP(GemmSP, gemm_sp) +TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); } +TVM_REGISTER_OP("tl.GemmSPWarpPolicy") + .set_attr("TScriptPrinterName", "GemmSPWarpPolicy"); +TVM_FFI_STATIC_INIT_BLOCK() { + GemmSPNode::RegisterReflection(); + GemmSPWarpPolicyNode::RegisterReflection(); + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tl.GemmSPWarpPolicyComputeWarpPartition", + [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target, + bool use_wgmma, int bits) { + policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits); + return; + }); +} } // namespace tl } // namespace tvm diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h index aae5b27bf..a634e922f 100644 --- a/src/op/gemm_sp.h +++ b/src/op/gemm_sp.h @@ -23,6 +23,14 @@ class GemmSPWarpPolicyNode : public GemmWarpPolicyNode { int bits) const; TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode, GemmWarpPolicyNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("policy_type", &GemmSPWarpPolicyNode::policy_type) + .def_ro("m_warp", &GemmSPWarpPolicyNode::m_warp) + .def_ro("n_warp", &GemmSPWarpPolicyNode::n_warp); + } }; class GemmSPWarpPolicy : public ObjectRef { diff --git a/src/op/gemm_sp_py.cc b/src/op/gemm_sp_py.cc new file mode 100644 index 000000000..6ad8ca9b5 --- /dev/null +++ b/src/op/gemm_sp_py.cc @@ -0,0 +1,289 @@ +/*! + * \file tl/op/gemm_sp_py.cc + * \brief Implementation of Sparse General Matrix Multiplication (GEMM_SP) + * operators + */ + +#include "gemm_sp_py.h" +#include "utils.h" + +#include "builtin.h" +#include +#include +#include +#include + +#include "../target/utils.h" +#include "tvm/ffi/string.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/** + * @brief Construct a Gemm operator from serialized TL arguments and a buffer + * map. + * + * This constructor deserializes operator parameters from `args` and resolves + * buffer references via `vmap`, populating an internal GemmSPPyNode with: + * - device pointers for A, E, B, C and their corresponding Buffer objects, + * - transpose flags for A and B, + * - matrix dimensions M, N, K, + * - warp allocation policy and clear_accum flag, + * - strides and memory offsets for A and B, + * - optional kPack (must be 1 or 2) and optional wg_wait. + * + * The populated GemmSPPyNode is stored into the wrapper's internal `data_`. 
+ * + * @param args Positional serialized arguments produced by the TL frontend: + * expected layout is: + * [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), + * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), + * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), + * (optional) kPack (Int), (optional) wg_wait (Int)] + * @param vmap Mapping from access pointer vars to Buffer objects used to + * resolve the Buffer corresponding to each pointer argument. + * + * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor + * fails with an ICHECK (runtime assertion). No other validation is + * performed here. + */ +GemmSPPy::GemmSPPy(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + + node->aRegion_ = NormalizeToBufferRegion(args[0]); + node->eRegion_ = NormalizeToBufferRegion(args[1]); + node->bRegion_ = NormalizeToBufferRegion(args[2]); + node->cRegion_ = NormalizeToBufferRegion(args[3]); + + node->A = node->aRegion_->buffer; + node->E = node->eRegion_->buffer; + node->B = node->bRegion_->buffer; + node->C = node->cRegion_->buffer; + + node->trans_A = args[4].as().value(); + node->trans_B = args[5].as().value(); + node->trans_E = args[6].as().value(); + node->M = args[7].as().value()->value; + node->N = args[8].as().value()->value; + node->K = args[9].as().value()->value; + node->policy = GemmWarpPolicy(args[10].as().value()->value); + node->clear_accum = args[11].as().value(); + node->stride_A = args[12].as().value()->value; + node->stride_B = args[13].as().value()->value; + node->offset_A = args[14].as().value()->value; + node->offset_B = args[15].as().value()->value; + if (args.size() > 16) { + node->kPack = args[16].as().value()->value; + if (node->kPack != 1 && node->kPack != 2) { + ICHECK(false) << "kPack must be 1 or 2"; + } + } + if (args.size() > 17) { + node->wg_wait = args[17].as().value()->value; + } + data_ = std::move(node); +} + +/** + * @brief Create a copy of this GemmSPPyNode as a TileOperator. + * + * Constructs a new GemmSPPyNode by copying the current node state and returns + * it wrapped in a GemmSPPy TileOperator. + * + * @return TileOperator A GemmSPPy operator that owns a copy of this node. + */ +TileOperator GemmSPPyNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return GemmSPPy(op); +} + +GemmInst GemmSPPyNode::GetGemmInst(int block_size, Target target) const { + int warp_size = TargetGetWarpSize(target); + int num_warps = block_size / warp_size; + bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && + (num_warps % 4 == 0) && CheckWGMMA(); + if (allow_wgmma) { + return GemmInst::kWGMMA; + } else if (TargetIsCDNA(target)) { + return GemmInst::kMFMA; + } else if (TargetIsCuda(target)) { + return GemmInst::kMMA; + } else { + ICHECK(0) << "Unsupported target for gemm: " << target->str(); + } +} + +/** + * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM. + * + * Evaluates device-memory placement, data-type combinations, transpose flags, + * and K divisibility constraints required for the Hopper WGMMA code path. + * + * The check returns true only when: + * - B resides in shared memory ("shared" or "shared.dyn"); and + * - (C, A, B) dtypes match one of the supported combinations below and K + * satisfies the required alignment; and + * - for combinations that require specific orientations, A is not transposed + * and B is transposed. 
+ * + * Supported combinations and constraints: + * - C=float16: + * - A=float16, B=float16: K % 16 == 0 + * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % + * 32 == 0 + * - C=float32: + * - A=float16, B=float16: K % 16 == 0 + * - A=bfloat16, B=bfloat16: K % 16 == 0 + * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 + * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 + * - C=int32: + * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) + * and K % 32 == 0 + * + * @return true if WGMMA is supported for the current buffers, dtypes, and + * transpose/shape constraints; false otherwise. + */ +bool GemmSPPyNode::CheckWGMMA() const { + return false; // not supported yet + // if (B.scope() != "shared.dyn" && B.scope() != "shared") { + // return false; + // } + + // if (C->dtype == DataType::Float(16)) { + // if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) + // return K % 16 == 0; + // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else + // return false; + // } else if (C->dtype == DataType::Float(32)) { + // if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) + // return K % 16 == 0; + // else if (A->dtype == DataType::BFloat(16) && + // B->dtype == DataType::BFloat(16)) + // return K % 16 == 0; + // else if (A->dtype == DataType::Float(32) && B->dtype == + // DataType::Float(32)) + // return (!trans_A) && trans_B && K % 8 == 0; + // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else + // return false; + // } else if (C->dtype == DataType::Int(32)) { + // if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8)) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8)) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8)) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8)) + // return (!trans_A) && trans_B && K % 32 == 0; + // else + // return false; + // } else { + // return false; + // } +} + +/** + * @brief Parse and return the numeric GPU architecture from a Target's "arch" + * attribute. + * + * Examines the target's "arch" string and, if it matches the pattern + * "sm_", returns as an int. If the attribute is present but does not + * match that pattern, returns 0. + * + * Preconditions: the target must have an "arch" attribute (this is checked via + * ICHECK). + * + * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if + * the arch string does not match "sm_". 
+ */ +static int GetArchInt(Target target) { + int arch_int = 0; + auto s = target->GetAttr("arch"); + ICHECK(s.has_value()); + std::string arch = s.value(); + if (arch.rfind("sm_", 0) == 0) { + arch_int = std::stoi(arch.substr(3)); + } else { + arch_int = 0; + } + return arch_int; +} + +Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + auto block_size = *as_const_int(T.thread_bounds->extent); + GemmInst gemm_inst = GetGemmInst(block_size, T.target); + + auto [warp_m, warp_n] = + policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst); + + if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.lower")) { + auto prim_func = + Downcast((*f)(tvm::ffi::GetRef(this), T.target, + T.thread_bounds, T.thread_var)); + ICHECK(prim_func->attrs.defined()); + auto global_symbol = prim_func->attrs.GetAttr("global_symbol"); + ICHECK(global_symbol.has_value()); + if (prim_func->body.as()) { + BlockRealize block_realize = Downcast(prim_func->body); + auto block = block_realize->block; + { + BlockNode *n = block.CopyOnWrite(); + n->name_hint = global_symbol.value(); + } + return BlockRealize(block_realize->iter_values, block_realize->predicate, + block); + } + // warp with block realize node + return BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/global_symbol.value(), prim_func->body)); + } else { + LOG(FATAL) << "No lower function found for gemm_sp_py"; + } +} + +LayoutMap GemmSPPyNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (completed_) + return {}; + LayoutMap results; + + if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.infer_layout")) { + results = Downcast( + (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); + } else { + LOG(FATAL) << "No infer layout function found for gemm_sp_py"; + } + + completed_ = true; + return results; +} + +TIR_REGISTER_TL_TILE_OP(GemmSPPy, gemm_sp_py) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { GemmSPPyNode::RegisterReflection(); } +} // namespace tl +} // namespace tvm diff --git a/src/op/gemm_sp_py.h b/src/op/gemm_sp_py.h new file mode 100644 index 000000000..2f79c5e15 --- /dev/null +++ b/src/op/gemm_sp_py.h @@ -0,0 +1,94 @@ +/*! + * \file tl/op/gemm_sp_py.h + * \brief Define gemm_sp_py operator. 
+ * + */ + +// TODO: @botbw: remove redundant code with gemm_py.h + +#ifndef TVM_TL_OP_GEMM_SP_PY_H_ +#define TVM_TL_OP_GEMM_SP_PY_H_ + +#include "gemm_sp.h" +#include "operator.h" + +namespace tvm { + +namespace tl { + +using namespace tir; + +class GemmSPPyNode : public TileOperatorNode { +public: + bool CheckWGMMA() const; + tir::Buffer A, E, B, C; + // pointer to the A, E, B, C + BufferRegion aRegion_, eRegion_, bRegion_, cRegion_; + bool trans_A, trans_B, trans_E; + int M, N, K; + int stride_A, stride_B; + int offset_A, offset_B; + PrimExpr clear_accum = const_false(); + // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack + // only will be enabled under cdna mfma instructions + int kPack = 1; + int wg_wait = 0; + + // use GemmWarp Policy here as the atom size are flexible in v2 + mutable GemmWarpPolicy policy; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("A", &GemmSPPyNode::A) + .def_ro("E", &GemmSPPyNode::E) + .def_ro("B", &GemmSPPyNode::B) + .def_ro("C", &GemmSPPyNode::C) + .def_ro("aRegion", &GemmSPPyNode::aRegion_) + .def_ro("eRegion", &GemmSPPyNode::eRegion_) + .def_ro("bRegion", &GemmSPPyNode::bRegion_) + .def_ro("cRegion", &GemmSPPyNode::cRegion_) + .def_ro("trans_A", &GemmSPPyNode::trans_A) + .def_ro("trans_B", &GemmSPPyNode::trans_B) + .def_ro("trans_E", &GemmSPPyNode::trans_E) + .def_ro("M", &GemmSPPyNode::M) + .def_ro("N", &GemmSPPyNode::N) + .def_ro("K", &GemmSPPyNode::K) + .def_ro("stride_A", &GemmSPPyNode::stride_A) + .def_ro("stride_B", &GemmSPPyNode::stride_B) + .def_ro("offset_A", &GemmSPPyNode::offset_A) + .def_ro("offset_B", &GemmSPPyNode::offset_B) + .def_ro("clear_accum", &GemmSPPyNode::clear_accum) + .def_ro("kPack", &GemmSPPyNode::kPack) + .def_ro("wg_wait", &GemmSPPyNode::wg_wait) + .def_ro("policy", &GemmSPPyNode::policy); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + TileOperator Clone() const; + +private: + // Target GEMM instruction + GemmInst GetGemmInst(int block_size, Target target) const; + + mutable bool completed_ = false; +}; + +class GemmSPPy : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator, + GemmSPPyNode); + TVM_DLL GemmSPPy(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_GEMM_SP_PY_H_ \ No newline at end of file diff --git a/src/op/operator.h b/src/op/operator.h index 0d9f859a7..1453f9c1e 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -77,12 +77,12 @@ TileOperator ParseOperator(Stmt stmt); using OpBuilderFunc = ffi::TypedFunction)>; -#define TIR_REGISTER_TL_OP(Entry, OpName) \ +#define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \ const Op &Entry::Get() { \ - static const Op &op = Op::Get("tl." #OpName); \ + static const Op &op = Op::Get("tl.tileop." #OpName); \ return op; \ } \ - TVM_REGISTER_OP("tl." #OpName) \ + TVM_REGISTER_OP("tl.tileop." 
#OpName) \ .set_attr("TScriptPrinterName", #OpName) \ .set_attr( \ "TLOpBuilder", [](Array args) { return Entry(args); }) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 0d09cc129..94572098d 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -252,17 +252,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, forward_vars.push_back( IterVar(Range(0, s), Var(), IterVarType::kDataPar)); } - Array forward_index; - for (const auto &iv : forward_vars) { - forward_index.push_back(iv->var); - } Var rep; auto rep_iter = IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar); + // Use default fragment indexing (single output dim) to + // stay consistent with other ops (e.g., ReduceOp), and + // bind the thread range for comparability. const PrimExpr &forward_thread = rep; - results.Set(buffer, Fragment(forward_vars, forward_index, - forward_thread, rep_iter)); + auto frag = Fragment(forward_vars, /*forward_index=*/{}, forward_thread, + rep_iter) + ->BindThreadRange(T.thread_bounds); + results.Set(buffer, frag); } } return results; diff --git a/src/op/reduce.cc b/src/op/reduce.cc index caf9198a7..40c9b83cd 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -478,7 +478,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, return {}; } -TIR_REGISTER_TL_OP(ReduceOp, reduce) +TIR_REGISTER_TL_TILE_OP(ReduceOp, reduce) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -563,7 +563,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T, return {}; } -TIR_REGISTER_TL_OP(CumSumOp, cumsum) +TIR_REGISTER_TL_TILE_OP(CumSumOp, cumsum) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/region.cc b/src/op/region.cc index 2a1f27456..25e78eba8 100644 --- a/src/op/region.cc +++ b/src/op/region.cc @@ -76,17 +76,7 @@ LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T, return {}; } -const Op &RegionOp::Get() { - static const Op &op = Op::Get("tl.region"); - return op; -} - -TVM_REGISTER_OP("tl.region") - .set_attr("TScriptPrinterName", "region") - .set_attr("TLOpBuilder", - [](Array args) { - return RegionOp(args); - }) +TIR_REGISTER_TL_TILE_OP(RegionOp, region) .set_num_inputs(-1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index 3d994bf5c..8b6ff61ba 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -52,10 +52,8 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { } else { FAIL; } - } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() || - ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() || - ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() || - ab_dtype.is_float4_e2m1fn()) && + } else if ((ab_dtype.is_float8() || ab_dtype.is_float6_e2m3fn() || + ab_dtype.is_float6_e3m2fn() || ab_dtype.is_float4_e2m1fn()) && ((c_dtype.is_float() && c_dtype.bits() == 32) || (c_dtype.is_float16() && c_dtype.bits() == 16))) { if (K % 32 != 0) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 99512b8be..65f23d8dd 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -312,7 +312,12 @@ std::string CodeGenTileLangCUDA::Finish() { void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) { if (op->kind == tir::ForKind::kUnrolled) { PrintIndent(); - stream << "#pragma unroll\n"; + if (unroll_factor.count(op->loop_var.get())) { + stream << "#pragma unroll " + << 
PrintExpr(unroll_factor[op->loop_var.get()]) << "\n"; + } else { + stream << "#pragma unroll\n"; + } } std::string extent = PrintExpr(arith::Analyzer().Simplify(op->extent + op->min)); @@ -2661,7 +2666,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) { this->stream << "const dim3 blockIdx = " << pattern->value << "();\n"; this->VisitStmt(op->body); return; + } else if (op->attr_key == "pragma_unroll_factor") { + const IntImmNode *factor = op->value.as(); + ICHECK(factor); + unroll_factor[op->node.as()] = Downcast(factor); } + CodeGenC::VisitStmt_(op); } diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index 6f229f11d..11c0ad081 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -140,6 +140,7 @@ class CodeGenTileLangCUDA final : public CodeGenC { std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; + std::unordered_map unroll_factor; friend void PrintConst(const FloatImmNode *op, std::ostream &os, CodeGenTileLangCUDA *p); void PrintWmmaScope(const std::string &scope, DataType t, diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index b92fc73bf..bf2a5100b 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -127,6 +127,16 @@ TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2, return result; } +TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0, + short z1, short w0, short w1) { + int4_t result; + *((short2 *)&result.x) = make_short2(x0, x1); + *((short2 *)&result.y) = make_short2(y0, y1); + *((short2 *)&result.z) = make_short2(z0, z1); + *((short2 *)&result.w) = make_short2(w0, w1); + return result; +} + // Pack eight int values. TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0, int z1, int w0, int w1) { diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h index 020cb1f16..3f8ce5e6b 100644 --- a/src/tl_templates/cuda/debug.h +++ b/src/tl_templates/cuda/debug.h @@ -108,6 +108,16 @@ __device__ void debug_print_buffer_value(const char *msg, const char *buf_name, PrintTraits::print_buffer(msg, buf_name, index, var); } +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, uint16_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=uint16_t value=%u\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (uint32_t)var); +} + TL_DEVICE void device_assert(bool cond) { assert(cond); } TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) { diff --git a/src/transform/atomicadd_vectorize.h b/src/transform/atomicadd_vectorize.h index a55bc0f4a..627dc895f 100644 --- a/src/transform/atomicadd_vectorize.h +++ b/src/transform/atomicadd_vectorize.h @@ -11,7 +11,6 @@ #include "../op/builtin.h" #include "arith/int_operator.h" #include "arith/ir_visitor_with_analyzer.h" -#include "atomicadd_vectorize.h" #include "common/loop_vectorization_utils.h" #include #include diff --git a/testing/python/analysis/test_tilelang_nested_loop_checker.py b/testing/python/analysis/test_tilelang_nested_loop_checker.py index b572a707a..d3c2ec20e 100644 --- a/testing/python/analysis/test_tilelang_nested_loop_checker.py +++ b/testing/python/analysis/test_tilelang_nested_loop_checker.py @@ -550,5 +550,178 @@ def test_mixed_pp(): run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1]) +""" +TiledOp in a T.Parallel is also not permitted. 
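+
+For example, the following fragment (an abridged copy of the
+matmul_with_parallel kernel below) is rejected by the pre-lower semantic
+check with a ValueError, because T.gemm is a tile-op and only elementwise
+operations are allowed inside a T.Parallel loop:
+
+    for _ in T.Parallel(1):
+        T.gemm(A_shared, B_shared, C_local, False, False)
+
+Plain TIR intrinsics (e.g. T.max) and customize ops such as T.atomic_add are
+still permitted inside T.Parallel, as the kernels further below verify.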
+""" + + +def matmul_with_parallel( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + threads, + order, + stage, +): + A_shape = (M, K) + B_shape = (K, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + for i, j in T.Parallel(block_M, block_K): + A_shared[i, j] = A[by * block_M + i, k * block_K + j] + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = B[k * block_K + i, bx * block_N + j] + + # T.copy(A[by * block_M, k * block_K], A_shared) + # T.copy(B[k * block_K, bx * block_N], B_shared) + + for _ in T.Parallel(1): + T.gemm(A_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_tiled_op_with_parallel( + order, + stage, +): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + in_dtype = "float16" + out_dtype = "float16" + dtypeAccum = "float32" + num_threads = 128 + + program = matmul_nested_pipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if in_dtype == "float32": + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, -0x1000 meas + A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) + B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + program1 = matmul_with_parallel( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + with pytest.raises(ValueError): + tilelang.compile( + program1, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + + +@tilelang.jit(out_idx=[1]) +def tir_op_with_parallel(length=256, block=16, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.Parallel(block): + B[i * block + j] = T.max(A[i * block + j], 0.0) + + return main + + +@tilelang.jit(out_idx=[1]) +def customize_op_with_parallel(length=256, block=16, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.Parallel(block): + B[i * block + j] = A[i * block + j] + T.atomic_add(B[i * block + j], 1.0) + + return main + + +def test_tiled_op_with_parallel(): + run_gemm_tiled_op_with_parallel(order=[0, 1, 
2], stage=[0, 0, 1]) + + kernel1 = tir_op_with_parallel(length=256, block=16) + data = _require_cuda_tensor((256,), torch.float32) + result1 = kernel1(data) + torch.testing.assert_close(result1, torch.relu(data), atol=1e-5, rtol=1e-5) + kernel2 = customize_op_with_parallel(length=256, block=16) + result2 = kernel2(data) + torch.testing.assert_close(result2, data + 1, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py index b4509fadc..13135d416 100644 --- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py @@ -52,7 +52,12 @@ def tl_matmul( micro_size_x = micro_size_y = micro_size_k = 16 - is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] + is_float8 = in_dtype in [ + "float8_e4m3", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fnuz", + ] if out_dtype == "int32" or is_float8: micro_size_k = 32 diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py index 34def174d..46f4e123a 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py @@ -51,7 +51,12 @@ def tl_matmul( micro_size_x = micro_size_y = micro_size_k = 16 - is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] + is_float8 = in_dtype in [ + "float8_e4m3", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fnuz", + ] if out_dtype == "int32" or is_float8: micro_size_k = 32 diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py index da2e12cdc..6e20754eb 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py @@ -52,7 +52,12 @@ def tl_matmul( micro_size_x = micro_size_y = micro_size_k = 16 - is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] + is_float8 = in_dtype in [ + "float8_e4m3", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fnuz", + ] if out_dtype == "int32" or is_float8: micro_size_k = 32 diff --git a/testing/python/language/test_tilelang_language_unroll.py b/testing/python/language/test_tilelang_language_unroll.py new file mode 100644 index 000000000..1796302e3 --- /dev/null +++ b/testing/python/language/test_tilelang_language_unroll.py @@ -0,0 +1,37 @@ +import tilelang.testing +from tilelang import tvm as tvm +from tilelang import language as T + + +def test_unroll_with_step(): + + @T.prim_func + def main(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) + + for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): + for i in T.unroll(0, 16, step=4): + A[0, i] = 1.0 + + kernel = tilelang.compile(main, target="cuda") + assert "#pragma unroll" in kernel.get_kernel_source() + + +def test_unroll_with_unroll_factor(): + + @T.prim_func + def main(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) + + for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): + for i in T.unroll(0, 16, unroll_factor=4): + A[0, i] = 1.0 + + kernel = tilelang.compile(main, target="cuda") + assert "#pragma unroll 4" in kernel.get_kernel_source() + + +if __name__ == "__main__": + 
tilelang.testing.main() diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index 74b9729f6..cefe986a0 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -2,28 +2,46 @@ import tilelang import tilelang.testing -from tilelang.utils.sparse import compress, randn_semi_sparse -from tilelang.layout import make_metadata_layout - -torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000) -torch.manual_seed(42) - -STR_TO_TYPE = { - 'float32': torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "float8_e4m3": torch.float8_e4m3fn, - "int8": torch.int8, - "int32": torch.int32, -} - -SPARSITY_MAP = { - # 'float32': (1, 2), # not supported for now - torch.float16: (2, 4), - torch.bfloat16: (2, 4), - torch.float8_e4m3fn: (2, 4), - torch.int8: (2, 4), -} +from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.tensor import torch_assert_close, map_torch_type +from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter + +torch.backends.cuda.matmul.allow_tf32 = False +# torch.manual_seed(42) # only enable when debugging + + +def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): + is_8bit = "8" in in_dtype + is_unsigned = "uint" in in_dtype + is_int = "int" in in_dtype + if is_int: + if is_8bit: + low, high = (0, 4) if is_unsigned else (-2, 2) + else: + low, high = (0, 128) if is_unsigned else (-64, 64) + A = randint_semi_sparse( + M, + K, + low=low, + high=high, + dtype=map_torch_type(in_dtype), + device='cuda', + transposed=trans_A) + B = torch.randint( + size=(N, K) if trans_B else (K, N), + low=low, + high=high, + dtype=map_torch_type(in_dtype), + device='cuda') + else: + A = randn_semi_sparse( + M, K, dtype=torch.float32, device='cuda', + transposed=trans_A).to(map_torch_type(in_dtype)) + B = torch.randn( + (N, K) if trans_B else (K, N), device='cuda', + dtype=torch.float32).to(map_torch_type(in_dtype)) + return A, B def matmul_sp_sm90( @@ -60,21 +78,17 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8') - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) T.annotate_layout({ E: - make_metadata_layout( - E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K), + make_cutlass_metadata_layout( + E, mma_dtype=in_dtype, arch="9.0", block_k=block_K), E_shared: - make_metadata_layout( - E_shared, - mma_dtype="float16", - arch="9.0", - backend="cutlass", - block_k=block_K), + make_cutlass_metadata_layout( + E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K), }) T.disable_warp_group_reg_alloc() - T.clear(C_local) + T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) if trans_A: @@ -85,8 +99,8 @@ def main( T.copy(B[bx * block_N, k * block_K], B_shared) else: T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * 
block_N]) return main @@ -107,7 +121,8 @@ def matmul_sp_sm80( trans_B, ): is_8_bit = "8" in in_dtype - E_factor = 32 if is_8_bit else 16 + metadata_dtype = 'int32' if is_8_bit else 'int16' + E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) B_shape = (K, N) if not trans_B else (N, K) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) @@ -118,22 +133,18 @@ def matmul_sp_sm80( @T.prim_func def main( A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), 'int32' if is_8_bit else 'int16'), + E: T.Tensor((M, K // E_factor), metadata_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), - 'int32' if is_8_bit else 'int16') + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) T.annotate_layout({ - E: - make_metadata_layout(E, mma_dtype="float16", backend="cutlass", arch="8.0"), - E_shared: - make_metadata_layout( - E_shared, mma_dtype="float16", backend="cutlass", arch="8.0"), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), }) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -181,19 +192,14 @@ def run_gemm_sp( kernel, out_idx=[-1], ) - A = randn_semi_sparse(M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', transposed=trans_A) - if trans_B: - B = torch.randn((N, K), device='cuda', dtype=torch.float32) - else: - B = torch.randn((K, N), device='cuda', dtype=torch.float32) - - if "float8" in in_dtype or "int8" in in_dtype: - A = normalize(A.float()) - B = normalize(B.float()) - - A = A.to(STR_TO_TYPE[in_dtype]) - B = B.to(STR_TO_TYPE[in_dtype]) - + A, B = generate_dense_input( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + ) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) C_sp = kernel(A_sparse, E, B) @@ -206,14 +212,22 @@ def _matmul(A, B): if "float8" in in_dtype or "int8" in in_dtype: A = A.to(torch.float32) B = B.to(torch.float32) - return torch.matmul(A, B).to(STR_TO_TYPE[out_dtype]) + return torch.matmul(A, B) C = _matmul(A, B) + if 'float8' in in_dtype: diff = calc_diff(C_sp, C) assert diff < 1e-3, f"{diff=}" else: - torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3) + torch_assert_close( + C_sp.to(torch.float32), + C.to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) print("pass") diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py new file mode 100644 index 000000000..a82c29f38 --- /dev/null +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -0,0 +1,666 @@ +from tilelang import tvm as tvm +from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse +from tilelang.utils.tensor import torch_assert_close, map_torch_type +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter + +import tilelang.testing +import torch + + +def 
matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_ss( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(map_torch_type(out_dtype)).to(torch.float32), + C.to(map_torch_type(out_dtype)).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): + is_8bit = "8" in in_dtype + is_unsigned = "uint" in in_dtype + is_int = "int" in in_dtype + if is_int: + if is_8bit: + low, high = (0, 4) if is_unsigned else (-2, 2) + else: + low, high = (0, 128) if is_unsigned else (-64, 64) + A = randint_semi_sparse( + M, + K, + low=low, + high=high, + dtype=map_torch_type(in_dtype), + device='cuda', + transposed=trans_A) + B = torch.randint( + size=(N, K) if trans_B else (K, N), + low=low, + high=high, + 
dtype=map_torch_type(in_dtype), + device='cuda') + else: + A = randn_semi_sparse( + M, K, dtype=map_torch_type(in_dtype), device='cuda', transposed=trans_A) + B = torch.randn( + (N, K) if trans_B else (K, N), device='cuda', + dtype=torch.float32).to(map_torch_type(in_dtype)) + return A, B + + +def test_gemm_ss(): + # More test case can be found in kernel/test_tilelang_kernel_gemm.py + # GEMM tests for float16 + # TODO: support transposed A compressor + run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2) + run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2) + run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2) + run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2) + + # n8 test + run_gemm_ss(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128) + + # int8 test + run_gemm_ss(128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2) + run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2) + + # float8 tests + run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, + 2) + run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) + + # tfloat32 test + # run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * 
block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.gemm_sp_v2(A_frag, E_shared, B_shared, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(map_torch_type(out_dtype)).to(torch.float32), + C.to(map_torch_type(out_dtype)).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +def test_gemm_rs(): + # GEMM tests for float16 + run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2) + + # n8 tests + run_gemm_rs(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128) + + # int8 tests + run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2) + + # float8 tests + run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) + + # float32 tests + # run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) + + +def matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: 
T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B_shared, B_frag) + T.gemm_sp_v2(A_shared, E_shared, B_frag, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + program = matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(map_torch_type(out_dtype)).to(torch.float32), + C.to(map_torch_type(out_dtype)).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +def test_gemm_sr(): + # GEMM tests for float16 + run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2) + + # n8 tests + run_gemm_sr(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128) + + # int8 tests + run_gemm_sr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2) + run_gemm_sr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2) + run_gemm_sr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_sr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2) + + # float8 tests + run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) + + # float32 tests + 
# run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) + + +def matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.copy(B_shared, B_frag) + T.gemm_sp_v2(A_frag, E_shared, B_frag, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + program = matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + 
C_sp.to(map_torch_type(out_dtype)).to(torch.float32), + C.to(map_torch_type(out_dtype)).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +def test_gemm_rr(): + # GEMM tests for float16 + run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2) + run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2) + # n8 tests + run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2) + run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2) + + # int8 tests + run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2) + run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2) + + # float8 tests + run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) + + # float32 tests + # run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) + # run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/analysis/ast_printer.py b/tilelang/analysis/ast_printer.py index c54ec5cf9..e634e0271 100644 --- a/tilelang/analysis/ast_printer.py +++ b/tilelang/analysis/ast_printer.py @@ -14,7 +14,7 @@ def pre_visit(statement: tir.Stmt) -> None: Pre-order visitor to print all visited statements. """ - print(f"Visiting statement: {type(statement)}") + print(f"Visiting statement: {type(statement)}, {statement}") def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc: new_body = ir_transform(func.body, pre_visit, None) diff --git a/tilelang/analysis/nested_loop_checker.py b/tilelang/analysis/nested_loop_checker.py index 7a0d94daa..eff0fc2db 100644 --- a/tilelang/analysis/nested_loop_checker.py +++ b/tilelang/analysis/nested_loop_checker.py @@ -1,6 +1,7 @@ from tvm import tir from tvm.tir import ( For, + Call, PrimFunc, PyStmtExprVisitor, ) @@ -17,6 +18,12 @@ def is_pipelined_for(op: For) -> bool: return any(key in op.annotations for key in anno_keys) +def is_tile_op(op: Call) -> bool: + """Check if a call is a tile-op""" + + return op.op.get_attr("TLOpBuilder") is not None + + @tir.functor.visitor class _NestedLoopCheckVisitor(PyStmtExprVisitor): @@ -39,7 +46,7 @@ def visit_for_(self, op: For) -> None: "Nested parallel loops are not allowed. " "Please check your loop structure.") self.in_parallel_context = True - self.visit_stmt(child) + super().visit_for_(op) self.in_parallel_context = False return elif is_pipelined_for(op): @@ -48,7 +55,14 @@ def visit_for_(self, op: For) -> None: "Pipelined loop cannot be nested inside a parallel loop. 
" "Please check your loop structure.") - self.visit_stmt(op.body) + super().visit_for_(op) + + def visit_call_(self, op: Call) -> None: + if self.in_parallel_context and is_tile_op(op): + raise ValueError("[Tilelang Semantic Check] " + "Only elementwise operations are allowed inside a parallel loop. " \ + f"Got a tile-op \"{op.op}\"." + ) def NestedLoopChecker(): diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 47ac888cf..9b2fca2c3 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -325,7 +325,7 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): key = self.generate_cache_key(parameters, extra_parameters) with self._lock: - if env.is_cache_enabled(): + if env.is_cache_enabled() and not env.is_autotune_cache_disabled(): # First check in-memory cache if key in self._memory_cache: logger.warning("Found kernel in memory cache. For better performance," \ @@ -601,7 +601,7 @@ def inner(**config_arg): logger.warning("DLPack backend does not support cache saving to disk.") else: with self._lock: - if env.is_cache_enabled(): + if env.is_cache_enabled() and not env.is_autotune_cache_disabled(): self._save_result_to_disk(key, autotuner_result) self._memory_cache[key] = autotuner_result diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 0d55cbf7d..0e6a19ba1 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -80,8 +80,8 @@ def compile_cuda(code, file_target = path_target if path_target else temp_target cmd = [get_nvcc_compiler()] cmd += [f"--{target_format}", "-O3"] - if kernels_output_dir is not None: - cmd += ["-lineinfo"] + # Always include line info for better profiling and mapping + cmd += ["-lineinfo"] if isinstance(arch, list): cmd += arch elif isinstance(arch, str): diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index dfa8050a3..1a98c8937 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -76,10 +76,8 @@ def PreLowerSemanticCheck(mod: IRModule) -> None: # Debug # tilelang.analysis.ASTPrinter()(mod) - # Check if there are any invalid nested loops. tilelang.analysis.NestedLoopChecker()(mod) - # Check if there are any invalid symbolic T.Parallel + fragment access. tilelang.analysis.FragmentLoopChecker()(mod) diff --git a/tilelang/env.py b/tilelang/env.py index 39d9e722e..b1697ef51 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -196,12 +196,6 @@ def __set__(self, instance, value): # os.environ[self.key] = value -# Cache control API (wrap CacheState) -enable_cache = CacheState.enable -disable_cache = CacheState.disable -is_cache_enabled = CacheState.is_enabled - - # Utility function for environment variables with defaults # Assuming EnvVar and CacheState are defined elsewhere class Environment: @@ -234,13 +228,18 @@ class Environment: # Kernel Build options TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "1") # print kernel name on compile - TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # clear cache automatically if set + TILELANG_DISABLE_CACHE = EnvVar( + "TILELANG_DISABLE_CACHE", + "0") # disable kernel cache, usually for unit testing / debugging, high priority + TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", + "0") # DEPRECATED! 
clear cache automatically if set # Kernel selection options # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1 TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0") # Auto-tuning settings + TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0") TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9") # percent of CPUs used TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", @@ -267,7 +266,7 @@ def _initialize_torch_cuda_arch_flags(self) -> None: # Cache control API (wrap CacheState) def is_cache_enabled(self) -> bool: - return CacheState.is_enabled() + return not self.is_cache_globally_disabled() and CacheState.is_enabled() def enable_cache(self) -> None: CacheState.enable() @@ -275,6 +274,12 @@ def enable_cache(self) -> None: def disable_cache(self) -> None: CacheState.disable() + def is_cache_globally_disabled(self) -> bool: + return self.TILELANG_DISABLE_CACHE.lower() in ("1", "true", "yes", "on") + + def is_autotune_cache_disabled(self) -> bool: + return self.TILELANG_AUTO_TUNING_DISABLE_CACHE.lower() in ("1", "true", "yes", "on") + def is_print_on_compilation_enabled(self) -> bool: return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on") @@ -290,6 +295,11 @@ def use_gemm_v1(self) -> bool: # Instantiate as a global configuration object env = Environment() +# Cache control API (wrap env, which is managed by CacheState and Environment Variables jointly) +enable_cache = env.enable_cache # CacheState.enable +disable_cache = env.disable_cache # CacheState.disable +is_cache_enabled = env.is_cache_enabled # CacheState.is_enabled + # Export CUDA_HOME and ROCM_HOME, both are static variables # after initialization. CUDA_HOME = env.CUDA_HOME diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index 449b6b943..f49b59569 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -151,12 +151,43 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id): return row, col +def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id): + """ + groupID = %laneid >> 2 + threadID_in_group = %laneid % 4 + + row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 + groupID + 8 Otherwise + + col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4 + (threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4 + """ + row = (thread_id // 4) + 8 * (local_id % 4 // 2) + col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4) + return row, col + + def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): row = 8 * (local_id // 8) + (thread_id // 4) col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4) return row, col +def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id): + """ + groupID = %laneid >> 2 + threadID_in_group = %laneid % 4 + + row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2 + (threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2 + + col = groupID + """ + col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8 + row = (thread_id // 4) + 8 * (local_id // 4) + return row, col + + def shared_16x16_to_mma_32x8_smoothlayout(i, j): return (i * 2 + j // 8, j % 8) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 6e49b0582..5811eb534 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -22,8 +22,10 @@ 
shared_16x32_to_mma_32x16_layout_sr_b, mma_load_a_32x4_to_shared_16x8_layout, mma_load_b_32x4_to_shared_16x8_layout, + mma_load_b_32x8_to_shared_16x16_layout, mma_load_a_32x16_to_shared_16x32_layout, mma_load_b_32x16_to_shared_16x32_layout, + mma_load_a_32x8_to_shared_16x16_layout, ) lift = convert @@ -291,6 +293,8 @@ def mma_load_layout(i, j): if not ldmatrix_available: if DataType(a_dtype).bits == 8: mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout + elif DataType(a_dtype).bits == 16: + mma_load_layout = mma_load_a_32x8_to_shared_16x16_layout elif DataType(a_dtype).bits == 32: mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout else: @@ -417,6 +421,8 @@ def mma_load_layout(i, j): if not ldmatrix_available: if DataType(b_dtype).bits == 8: mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout + elif DataType(b_dtype).bits == 16: + mma_load_layout = mma_load_b_32x8_to_shared_16x16_layout elif DataType(b_dtype).bits == 32: mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout else: diff --git a/tilelang/intrinsics/mma_sp_layout.py b/tilelang/intrinsics/mma_sp_layout.py new file mode 100644 index 000000000..bae86bf45 --- /dev/null +++ b/tilelang/intrinsics/mma_sp_layout.py @@ -0,0 +1,190 @@ +from tvm import DataType +from typing import Literal + +from tilelang.intrinsics.mma_layout import ( + mma_load_a_32x4_to_shared_16x8_layout, + mma_load_a_32x16_to_shared_16x32_layout, + mma_load_a_32x8_to_shared_16x16_layout, + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, +) + + +def shared_16x16_to_mma_sp_layout_sr_a(i, j): + return shared_16x8_to_mma_32x4_layout_sr_a(i, j) + + +def shared_16x16_to_mma_sp_layout_sr_b(i, j): + thread_id = 4 * (i % 8) + (j % 4) + return thread_id, 4 * (i // 8) + (j // 4) + + +def shared_16x32_to_mma_sp_layout_sr_a(i, j): + return shared_16x16_to_mma_32x8_layout_sr_a(i, j) + + +def shared_16x32_to_mma_sp_layout_sr_b(i, j): + thread_id = 4 * (i % 8) + (j % 8) // 2 + return thread_id, 8 * (i // 8) + (j // 8) * 2 + (j % 2) + + +def shared_16x64_to_mma_sp_layout_sr_a(i, j): + return shared_16x32_to_mma_32x16_layout_sr_a(i, j) + + +def shared_16x64_to_mma_sp_layout_sr_b(i, j): + thread_id = 4 * (i % 8) + (j % 16) // 4 + return thread_id, 16 * (i // 8) + (j // 16) * 4 + j % 4 + + +def mma_sp_load_a_32x4_to_shared_16x16_layout(thread_id, local_id): + return mma_load_a_32x4_to_shared_16x8_layout(thread_id, local_id) + + +def mma_sp_load_a_32x8_to_shared_16x32_layout(thread_id, local_id): + return mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id) + + +def mma_sp_load_a_32x16_to_shared_16x64_layout(thread_id, local_id): + return mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id) + + +def mma_sp_load_b_32x8_to_shared_16x16_layout(thread_id, local_id): + col = 4 * (local_id % 4) + (thread_id % 4) + row = 8 * (local_id // 4) + (thread_id // 4) + return row, col + + +def mma_sp_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): + col = (thread_id % 4) * 2 + (local_id % 2) + ((local_id % 8) // 2) * 8 + row = (thread_id // 4) + 8 * (local_id // 8) + return row, col + + +def mma_sp_load_b_32x32_to_shared_16x64_layout(thread_id, local_id): + col = (thread_id % 4) * 4 + (local_id % 4) + 16 * ((local_id % 16) // 4) + row = (thread_id // 4) + 8 * (local_id // 16) + return row, col + + +def get_logical_id_32bit(thread_id: int) -> int: + return (thread_id // 4) * 2 + (thread_id % 4) % 2 + + +def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, + local_id: int) 
-> tuple[int, int]: + logical_id = get_logical_id_32bit(thread_id) + row = logical_id // 4 + local_id * 8 + col = logical_id % 4 + return row, col + + +def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, + local_id: int) -> tuple[int, int]: + logical_id = get_logical_id_32bit(thread_id) + row = logical_id // 2 + local_id * 8 + col = logical_id % 2 + return row, col + + +def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, + local_id: int) -> tuple[int, int]: + return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit( + thread_id, local_id) # same mapping for 16bit and 32bit + + +def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, + local_id: int) -> tuple[int, int]: + return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit( + thread_id, local_id) # same mapping for 16bit and 32bit + + +def get_logical_id_8bit(thread_id: int) -> int: + return thread_id + + +def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, + local_id: int) -> tuple[int, int]: + logical_id = get_logical_id_8bit(thread_id) + row = logical_id // 2 + local_id * 8 + col = (logical_id % 4) // 2 * 4 + local_id + return row, col + + +def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, + local_id: int) -> tuple[int, int]: + logical_id = get_logical_id_8bit(thread_id) + row = logical_id // 2 + local_id * 8 + col = (logical_id % 4) // 2 * 2 + local_id + return row, col + + +def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, + local_id: int) -> tuple[int, int]: + # local_id is always 0 + logical_id = get_logical_id_8bit(thread_id) + row = logical_id // 4 + (logical_id % 2) * 8 + col = (logical_id % 4) // 2 + return row, col + + +def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): + row = (local_id // 4) * 8 + thread_id % 8 + col = (thread_id // 8) * 4 + local_id % 4 + return row, col + + +def ldmatrix_32x16_to_shared_32x16_layout(thread_id, local_id): + row = thread_id + col = local_id % 8 + 8 * (local_id // 8) + return row, col + + +def ldmatrix_trans_32x16_to_shared_16x32_layout(thread_id, local_id): + row = 8 * (local_id // 8) + thread_id % 8 + col = (thread_id // 8) * 8 + local_id % 8 + return row, col + + +def ldmatrix_trans_32x32_to_shared_shared_16x64_layout(thread_id, local_id): + row = (local_id // 16) * 8 + thread_id % 8 + col = (thread_id // 8) * 16 + local_id % 16 + return row, col + + +def get_ldmatrix_offset_b( + matrix: Literal["B"], + row_idx, + col_idx, + stride, + dtype: Literal["float16", "int8"] = "float16", + transposed: bool = False, +): + assert matrix == "B", "matrix should be B" + dtype_bits = DataType(dtype).bits + if dtype_bits == 32: + if transposed: + transform_func = ldmatrix_trans_32x8_to_shared_16x16_layout + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed for 32-bit dtype") + elif dtype_bits == 16: + transform_func = ldmatrix_32x16_to_shared_32x16_layout + transform_func_trans = ldmatrix_trans_32x16_to_shared_16x32_layout + if transposed: + new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif dtype_bits == 8: + if transposed: + transform_func = ldmatrix_trans_32x32_to_shared_shared_16x64_layout + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return 
new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed for 8-bit dtype") + else: + raise ValueError(f"Unsupported dtype {dtype}") diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/intrinsics/mma_sp_macro_generator.py new file mode 100644 index 000000000..629d95d99 --- /dev/null +++ b/tilelang/intrinsics/mma_sp_macro_generator.py @@ -0,0 +1,864 @@ +from __future__ import annotations + +import tilelang.language as T +from typing import Literal, Callable +from tvm import DataType +from tvm.tir import PrimExpr, IndexMap, Buffer, Var +from tvm.runtime import convert +from .utils import ( + mma_store_index_map, + get_ldmatrix_offset, +) +from tilelang.utils import is_fragment + +from tilelang.intrinsics.mma_sp_layout import ( + shared_16x16_to_mma_sp_layout_sr_a, + shared_16x16_to_mma_sp_layout_sr_b, + shared_16x32_to_mma_sp_layout_sr_a, + shared_16x32_to_mma_sp_layout_sr_b, + shared_16x64_to_mma_sp_layout_sr_a, + shared_16x64_to_mma_sp_layout_sr_b, + mma_sp_load_a_32x4_to_shared_16x16_layout, + mma_sp_load_a_32x8_to_shared_16x32_layout, + mma_sp_load_a_32x16_to_shared_16x64_layout, + mma_sp_load_b_32x8_to_shared_16x16_layout, + mma_sp_load_b_32x16_to_shared_16x32_layout, + mma_sp_load_b_32x32_to_shared_16x64_layout, + metadata_8bit_load_32x4_to_shared_16x4_layout_32bit, + metadata_16bit_load_32x2_to_shared_16x2_layout_32bit, + metadata_8bit_load_32x4_to_shared_16x4_layout_16bit, + metadata_16bit_load_32x2_to_shared_16x2_layout_16bit, + metadata_8bit_load_32x4_to_shared_16x4_layout_8bit, + metadata_16bit_load_32x2_to_shared_16x4_layout_8bit, + metadata_32bit_load_32x1_to_shared_16x2_layout_8bit, + get_ldmatrix_offset_b, +) + +lift = convert + + +class SparseTensorCoreIntrinEmitter: + """ + To eliminate Python syntax within TIR Macro. 
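+
+    Emits the per-warp building blocks (fragment loads, metadata loads and
+    sparse MMA calls) for 2:4 structured-sparse tensor core GEMM (1:2 for
+    tf32): A is stored compressed to K // SPARSE_FACTOR columns and the
+    metadata tensor E encodes which elements of each group are kept.
+
+    A minimal usage sketch (the parameter values below are illustrative
+    assumptions, not taken from any test in this patch):
+
+        emitter = SparseTensorCoreIntrinEmitter(
+            a_dtype="float16", e_dtype="int16", b_dtype="float16",
+            accum_dtype="float32",
+            block_row_warps=2, block_col_warps=2,
+            warp_row_tiles=64, warp_col_tiles=64, warp_k=32)
+        # float16 -> k_dim = 256 // 16 * 2 = 32 (sparse m16n8k32 MMA),
+        # e_factor = E_FACTOR_MAP["float16"]["int16"] = 16, and A is expected
+        # in compressed form with K // 2 columns plus the metadata tensor E.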
+ """ + + M_DIM = 16 + SPARSE_FACTOR = 2 # 1:2 for tfloat12, 2:4 for 16-bit and 8-bit datatypes + SPARSE_SELECTOR = 0 # always use lower threads to provide metadata + # use lowercase as n_dim can be dynamic + # the smallest instructions can be m16n8k16, so the n_dim can also be 8 + n_dim = 16 + WARP_SIZE = 32 + dtype_abbrv = { + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "int8": "int8", + "int32": "int32", + "float8_e4m3": "e4m3", + "float8_e5m2": "e5m2", + } + + E_FACTOR_MAP = { # e_kdim = mma_kdim // e_factor + "float": { + "int16": 8, + "uint16": 8, + }, + "float32": { + "int16": 8, + "uint16": 8, + }, + "float16": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "bfloat16": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "int8": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "uint8": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "float8_e4m3": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "float8_e5m2": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + } + + E_REPLICATE_FACTOR = { # metadata replicate every 4 consecutive threads + "float32": 2, + "float16": 2, # 2 of 4 consecutive threads provides + "bfloat16": 2, + "int8": 1, # 4 of 4 consecutive threads provides + "uint8": 1, + "float8_e4m3": 1, + "float8_e5m2": 1, + } + + # Represent the thread binding in the form of (tx, warp_n, warp_m) + is_m_first = False + + def __init__( + self, + a_dtype: str = "float16", + e_dtype: str = "uint8", + b_dtype: str = "float16", + accum_dtype: str = "float16", + a_transposed: bool = False, + b_transposed: bool = False, + e_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + warp_k: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: bool = False, + thread_var: Var | None = None, + ): + self.a_dtype = a_dtype + self.e_dtype = e_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + self.a_transposed = a_transposed + self.b_transposed = b_transposed + self.e_transposed = e_transposed + # Hint Information + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.warp_k = warp_k + self.e_factor = self.E_FACTOR_MAP[self.a_dtype][self.e_dtype] + self._initialize_k_dim(a_dtype) + self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) + self._initialize_micro_size(self.M_DIM, self.k_dim) + self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE) + self._initialize_mma_sp_prefix(self.k_dim) + self._initialize_is_m_first(is_m_first) + + self.reduce_k = reduce_k + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + self.num_elems_per_byte = num_elems_per_byte + self.thread_var = thread_var + + if self.warp_rows == 0 or self.warp_cols == 0: + raise ValueError( + f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" + ) + + def _initialize_k_dim(self, a_dtype="float16"): + if isinstance(a_dtype, str): + a_dtype = DataType(a_dtype) + # NOTE: k_dim here represents the logical shape of the MMA operation. 
+ # When referring to the physical data movement, it should be divided by sparse_factor. + self.k_dim = 256 // a_dtype.bits * self.SPARSE_FACTOR + + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR + self.local_size_e = ( + m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype] + self.local_size_b = (n_dim * k_dim) // warp_size + self.local_size_out = (m_dim * n_dim) // warp_size + + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): + self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] + self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] + self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + + def _initialize_mma_sp_prefix(self, k_dim: int = 16): + if k_dim == 16: + # typically used for tfloat32 + self.mma_prefix = "m16n8k16" + elif k_dim == 32: + # typically used for float16/bfloat16 + self.mma_prefix = "m16n8k32" + elif k_dim == 64: + # typically used for int8/fp8 + self.mma_prefix = "m16n8k64" + else: + raise ValueError("Unsupported k_dim") + + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" + assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" + assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" + assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" + + self.warp_rows = warp_row_tiles // m_dim + + if warp_col_tiles % 16 == 0: + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + else: + # must be divisible by 8 + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_cols = warp_col_tiles // 8 + + self.micro_size_x = m_dim + # NOTE: k_dim here represents the logical shape of the MMA operation. 
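+        # For example, with 16-bit inputs k_dim = 256 // 16 * SPARSE_FACTOR = 32, so
+        # micro_size_k = 32 while each m16n8k32 tile physically loads only 16 columns of A.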
+ self.micro_size_k = k_dim + + def _initialize_is_m_first(self, is_m_first: bool | None = False): + if is_m_first is not None: + self.is_m_first = is_m_first + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + + def get_store_index_map(self, inverse: bool = False) -> IndexMap: + warp_size, local_size_c = self.WARP_SIZE, self.local_size_out + index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32") + if not inverse: + return index_map + inverse_index_map = index_map.inverse([warp_size, local_size_c]) + return inverse_index_map + + def extract_thread_binding( + self, + thread_id: PrimExpr, + is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + # if is_m_first is None, then use the default value + if is_m_first is None: + is_m_first = self.is_m_first + + if is_m_first: + lane_id, warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) + return lane_id, warp_n, warp_m + else: + lane_id, warp_m, warp_n = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) + return lane_id, warp_n, warp_m + + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + warp_k = self.warp_k + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + a_dtype = self.a_dtype + a_transposed = self.a_transposed + # ldmatrix cannot be used for int8 + trans case. 
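+        # i.e. ldmatrix.trans only handles 16-bit elements, so a transposed A with another
+        # element width falls back to the per-thread copy via mma_load_layout below.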
+ ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed) + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(a_dtype).bits == 8: + mma_load_layout = mma_sp_load_a_32x16_to_shared_16x64_layout + elif DataType(a_dtype).bits == 16: + mma_load_layout = mma_sp_load_a_32x8_to_shared_16x32_layout + elif DataType(a_dtype).bits == 32: + mma_load_layout = mma_sp_load_a_32x4_to_shared_16x16_layout + else: + raise ValueError(f"Unsupported dtype: {a_dtype}") + + thread_binding = self.get_thread_binding() + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + stride = A_shared_buf.shape[-1] + tx, _, warp_m = self.extract_thread_binding(thread_binding) + trans = self.a_transposed + + for i in T.serial(warp_rows): + # Assign A_shared_buf_elem + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, ( + rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR + A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] + + if ldmatrix_available: + T.ptx_ldmatrix( + a_dtype, + T.bool(trans), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_buf_elem), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + else: + for j in T.serial(local_size_a): + mi, mk = mma_load_layout(tx, j) + A_local_buf[i * local_size_a + + j] = A_shared_buf[wk + mk, wi + + mi] if a_transposed else A_shared_buf[wi + mi, + wk + mk] + + return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) + + def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + warp_k = self.warp_k + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_e = self.local_size_e + a_dtype = self.a_dtype + e_dtype = self.e_dtype + trans = self.e_transposed + # ldmatrix cannot be used for int8 + trans case. 
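+        # The metadata fragment is small (local_size_e elements per lane), so a plain
+        # per-thread copy via mma_load_layout is used for now; layout reference below.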
+ # include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h + ldmatrix_available = False # TODO: use ldmatrix when possible + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(e_dtype).bits == 8: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_8bit + elif DataType(a_dtype).bits == 16: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_16bit + elif DataType(a_dtype).bits == 32: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_32bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 8bit: {a_dtype}") + elif DataType(e_dtype).bits == 16: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x4_layout_8bit + elif DataType(a_dtype).bits == 16: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_16bit + elif DataType(a_dtype).bits == 32: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_32bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 16bit: {a_dtype}") + elif DataType(e_dtype).bits == 32: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_32bit_load_32x1_to_shared_16x2_layout_8bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 32bit: {a_dtype}") + else: + raise ValueError(f"Unsupported dtype: {e_dtype}") + + thread_binding = self.get_thread_binding() + + @T.macro + def _warp_ldmatrix_e( + E_local_buf, + E_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + for i in T.serial(warp_rows): + # Assign E_shared_buf_elem + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, ( + rk * warp_k + ki * micro_size_k) // self.e_factor + for j in T.serial(local_size_e): + mi, mk = mma_load_layout(tx, j) + E_local_buf[i * local_size_e + + j] = E_shared_buf[wk + mk, + wi + mi] if trans else E_shared_buf[wi + mi, + wk + mk] + + return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk) + + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + warp_k = self.warp_k + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + b_dtype = self.b_dtype + b_transposed = self.b_transposed + thread_binding = self.get_thread_binding() + replicate_b = (self.n_dim == 16) + # ldmatrix cannot be used for int8 + trans case. 
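+        # mma.sp consumes B in col-major ("row.col"), so a non-transposed B needs ldmatrix.trans,
+        # which only supports 16-bit element types; other widths use the manual copy below.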
+ ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(b_dtype).bits == 8: + mma_load_layout = mma_sp_load_b_32x32_to_shared_16x64_layout + elif DataType(b_dtype).bits == 16: + mma_load_layout = mma_sp_load_b_32x16_to_shared_16x32_layout + elif DataType(b_dtype).bits == 32: + mma_load_layout = mma_sp_load_b_32x8_to_shared_16x16_layout + else: + raise ValueError(f"Unsupported dtype: {b_dtype}") + + @T.macro + def _warp_ldmatrix_b( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + stride = B_shared_buf.shape[-1] + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + trans = not b_transposed + + for i in T.serial(warp_cols): + # Assign B_shared_elem + wi, wk = ( + warp_n * warp_col_tiles + i * micro_size_y, + rk * warp_k + ki * micro_size_k, + ) + + if ldmatrix_available: + B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, + wi] + + if replicate_b: + T.ptx_ldmatrix( + b_dtype, + T.bool(trans), + 4, + ".b16", + B_local_buf.data, + i * local_size_b, + T.address_of(B_shared_buf_elem), + get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed), + ) + + T.ptx_ldmatrix( + b_dtype, + T.bool(trans), + 4, + ".b16", + B_local_buf.data, + i * local_size_b + lift(local_size_b) // 2, + T.address_of(B_shared_buf_elem), + get_ldmatrix_offset_b("B", tx, + lift(local_size_b) // 2, stride, b_dtype, + b_transposed), + ) + else: + T.ptx_ldmatrix( + b_dtype, + T.bool(trans), + 4, + ".b16", + B_local_buf.data, + i * local_size_b, + T.address_of(B_shared_buf_elem), + get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed), + ) + + else: + # load 16x32 data from shared buffer to local buffer + # must be transposed. 
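+                    # each lane gathers its local_size_b elements (e.g. 16 for the fp16 16x32 tile)
+                    # through mma_load_layout, which maps (lane, j) to (row, col) in the shared tile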
+ for j in T.serial(local_size_b): + mi, mk = mma_load_layout(tx, j) + B_local_buf[i * local_size_b + + j] = B_shared_buf[wi + mi, wk + + mk] if b_transposed else B_shared_buf[wk + mk, + wi + mi] + + return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) + + def mma_sp(self, + A_local_buf: Buffer, + E_local_buf: Buffer, + B_local_buf: Buffer, + C_local_buf: Buffer, + k_inner: PrimExpr = 0): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_e = self.local_size_e + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + mma_prefix = self.mma_prefix + replicate_b = (self.n_dim == 16) + + a_is_fragment = is_fragment(A_local_buf) + e_is_fragment = is_fragment(E_local_buf) + b_is_fragment = is_fragment(B_local_buf) + assert not e_is_fragment, f"currently E_local_buf must be a local allocation, found {E_local_buf.scope()}" + a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0 + e_local_stride: PrimExpr = k_inner * warp_rows * local_size_e if e_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0 + + @T.macro + def _warp_mma_sp(A_local_buf, E_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.ptx_mma_sp( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + a_local_stride + i * local_size_a, + B_local_buf.data, + b_local_stride + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + E_local_buf.data, # metadata + e_local_stride + i * local_size_e, # metadata offset + self.SPARSE_SELECTOR, # sparse_selector + T.bool(False), # saturate + ) + if replicate_b: + T.ptx_mma_sp( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + a_local_stride + i * local_size_a, + B_local_buf.data, + b_local_stride + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + + lift(local_size_out) // 2, + E_local_buf.data, # metadata + e_local_stride + i * local_size_e, # metadata offset + self.SPARSE_SELECTOR, # sparse_selector + T.bool(False), # saturate + ) + + return _warp_mma_sp(A_local_buf, E_local_buf, B_local_buf, C_local_buf) + + def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_out = self.local_size_out + + is_global = pid_m is not None and pid_n is not None + BLOCK_M = block_row_warps * warp_rows + BLOCK_N = block_col_warps * warp_cols + M_DIM, n_dim = self.M_DIM, self.n_dim + C_buf_dims = len(C_buf.shape) + assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" + + thread_binding = self.get_thread_binding() + + # STS + # MMA Store must be in simulated instead of TVM Intrins + # As TVM Intrins is like a hack that the threadIdx.x should be always + # equal to the warp_size + @T.macro + def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + for i, j in T.grid(warp_rows, warp_cols): + for local_id_o in T.serial(local_size_out // 2): + for local_id_i in T.vectorized(2): + 
local_id = local_id_o * 2 + local_id_i + row, col = T.meta_var(mma_store_index_map(tx, local_id)) + if C_buf_dims == 2: + C_buf[(warp_m * warp_rows + i) * M_DIM + row, + (warp_n * warp_cols + j) * n_dim + + col] = C_local_buf[i * (warp_cols * local_size_out) + + j * local_size_out + local_id] + else: + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, + col] = C_local_buf[i * (warp_cols * local_size_out) + + j * local_size_out + local_id] + + @T.macro + def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + for i, j in T.grid(warp_rows, warp_cols): + for local_id_o in T.serial(local_size_out // 2): + for local_id_i in T.vectorized(2): + local_id = local_id_o * 2 + local_id_i + row, col = T.meta_var(mma_store_index_map(tx, local_id)) + C_buf[ + (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, + (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + + local_id] + + return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)) + + def make_mma_load_layout(self, + local_buf: Buffer, + matrix: Literal["A", "B"] = "A") -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" + matrix_is_a: bool = matrix == "A" + matrix_is_b: bool = matrix == "B" + dtype = self.a_dtype if matrix_is_a else self.b_dtype + dtype_bits = DataType(dtype).bits + transposed = self.a_transposed if matrix_is_a else self.b_transposed + + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + if dtype_bits == 32: + transform_func_sr_a = shared_16x16_to_mma_sp_layout_sr_a + transform_func_sr_b = shared_16x16_to_mma_sp_layout_sr_b + elif dtype_bits == 16: + transform_func_sr_a = shared_16x32_to_mma_sp_layout_sr_a + transform_func_sr_b = shared_16x32_to_mma_sp_layout_sr_b + elif dtype_bits == 8: + transform_func_sr_a = shared_16x64_to_mma_sp_layout_sr_a + transform_func_sr_b = shared_16x64_to_mma_sp_layout_sr_b + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix_is_a and not transposed) + is_sr_conditions.append(matrix_is_b and transposed) + is_sr_axis_order = any(is_sr_conditions) + + # the layout of mma.sync is row.col. 
+ # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix_is_a: + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( + j, i) + elif matrix_is_b: + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( + j, i) + else: + raise ValueError(f"Unsupported matrix {matrix}") + + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" + + if matrix_is_a: + micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k + else: + micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y + + block_row_warps, block_col_warps = ( + self.block_row_warps, + self.block_col_warps, + ) + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] if is_sr_axis_order + else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + warp_rows, warp_cols = self.warp_rows, self.warp_cols + chunk = self.warp_k + + warp_s = warp_rows if matrix_is_a else warp_cols + warp_r = chunk // micro_size_r + block_s = block_row_warps if matrix_is_a else block_col_warps + replicate = block_col_warps if matrix_is_a else block_row_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([warp_s, warp_r], + repeat_on_thread=False, + lower_dim_first=False) + if matrix_is_a: + block_fragment = warp_fragment.repeat([block_s, 1], + repeat_on_thread=True, + lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], + repeat_on_thread=True, + lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + else: + warp_fragment = base_fragment.repeat([warp_r, warp_s], + repeat_on_thread=False, + lower_dim_first=True) + if matrix_is_a: + block_fragment = warp_fragment.repeat([1, block_s], + repeat_on_thread=True, + lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], + repeat_on_thread=True, + lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + + return block_fragment + + def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. 
+ """ + from tilelang.utils import is_fragment + + shape = local_buf.shape + inverse_mma_store_layout = self.get_store_index_map(inverse=True) + assert is_fragment(local_buf), "local_buf must be a fragment" + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + local_size_out = self.local_size_out + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + warp_size = self.WARP_SIZE + is_m_first = self.is_m_first + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols + block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols + # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y + mma_i, mma_j = i % micro_size_x, j % micro_size_y + lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j]) + if is_m_first: + thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id + else: + thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id + return thread_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of warp_i and warp_j are warp_rows and warp_cols + warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols + # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y + mma_i, mma_j = i % micro_size_x, j % micro_size_y + _, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j]) + return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id + + return T.Fragment( + shape, + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) diff --git a/tilelang/ir.py b/tilelang/ir.py index cccf97e0a..08d4e96cd 100644 --- a/tilelang/ir.py +++ b/tilelang/ir.py @@ -39,6 +39,19 @@ def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target return self.m_warp, self.n_warp +@tvm_ffi.register_object("tl.GemmSPWarpPolicy") +class GemmSPWarpPolicy(Node, Scriptable): + policy_type: int + m_warp: int + n_warp: int + + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, + is_wgmma: bool, bits: int): + _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, + is_wgmma, bits) + return self.m_warp, self.n_warp + + @tvm_ffi.register_object("tl.Gemm") class Gemm(Node, Scriptable): ... 
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 75d8d0b4f..1f560a446 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -24,7 +24,15 @@ LocalBuffer, # noqa: F401 Ref, # noqa: F401 ) -from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401 +from .loop import ( + Parallel, # noqa: F401 + Persistent, # noqa: F401 + Pipelined, # noqa: F401 + serial, # noqa: F401 + unroll, # noqa: F401 + Serial, # noqa: F401 + Unroll, # noqa: F401 +) from .frame import has_let_value, get_let_value # noqa: F401 from .math_intrinsics import * # noqa: F401 from .kernel import ( @@ -51,7 +59,7 @@ ) from .copy import copy, c2d_im2col # noqa: F401 from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2 # noqa: F401 -from .experimental.gemm_sp import gemm_sp # noqa: F401 +from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401 from .fill import fill, clear # noqa: F401 from .reduce import ( reduce, # noqa: F401 diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 56f87473f..07e45bbc8 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -212,9 +212,9 @@ def get_extent(data): "return_prev is not supported for tile-region-based atomic operations") if memory_order is None: - return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, 0) + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0) else: - return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, _MEMORY_ORDER_ID_MAP[memory_order]) diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 965919fd4..cabc4a3e4 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -90,7 +90,7 @@ def get_extent(data): eviction_policy = 0 else: eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] - return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width, + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, disable_tma, eviction_policy) @@ -124,5 +124,5 @@ def c2d_im2col(img: tir.Buffer, eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] img_region = to_buffer_region(img, access_type="r") col_region = to_buffer_region(col, access_type="w") - return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img_region, col_region, + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region, nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy) diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py index 7cc3d736d..4a20f3fb6 100644 --- a/tilelang/language/experimental/gemm_sp.py +++ b/tilelang/language/experimental/gemm_sp.py @@ -3,7 +3,15 @@ from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T from tvm import tir -from tilelang.utils.language import to_buffer_region +from tilelang.utils.language import ( + to_buffer_region, + retrieve_shape, + retrieve_stride, + retrieve_offset, + prim_expr_equal, +) +from tilelang.language.utils import ( + buffer_region_to_tile_region,) def gemm_sp( @@ -70,7 +78,7 @@ def legalize_arguments(arg: tir.Buffer | tir.Var): C_arg = to_buffer_region(C, access_type="rw") return tir.call_intrin( "handle", - tir.op.Op.get("tl.gemm_sp"), + 
tir.op.Op.get("tl.tileop.gemm_sp"), A_arg, E_arg, B_arg, @@ -85,3 +93,128 @@ def legalize_arguments(arg: tir.Buffer | tir.Var): k_pack, wg_wait, ) + + +# experimental currently, for fast compilation +def gemm_sp_v2( + A_sparse: tir.Buffer | tir.Var, + E: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, + transpose_A: bool = False, + transpose_B: bool = False, + transpose_E: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, +): + """Perform a General Matrix Multiplication (GEMM) operation. + + This function computes C = A @ B where A and B can optionally be transposed. + The operation supports various warp policies and accumulation modes. + + Args: + A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements + E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E + B (Union[tir.Buffer, tir.Var]): Second input matrix + C (Union[tir.Buffer, tir.Var]): Output matrix for results + transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False. + transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False. + policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square. + clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. + k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. + wg_wait (int, optional): Warp group wait count. Defaults to 0. + + Returns: + tir.Call: A handle to the GEMM operation + + Raises: + AssertionError: If the K dimensions of matrices A and B don't match + """ + + def legalize_arguments(arg: tir.Buffer | tir.Var): + """Convert let-bound variables to their corresponding buffers. 
+ + Args: + arg (Union[tir.Buffer, tir.Var]): Input argument to legalize + + Returns: + Union[tir.Buffer, tir.Var]: The legalized argument + """ + if isinstance(arg, tir.Var) and T.has_let_value(arg): + return T.get_let_value(arg).buffer + return arg + + A_sparse = legalize_arguments(A_sparse) + E = legalize_arguments(E) + B = legalize_arguments(B) + C = legalize_arguments(C) + + A_region = to_buffer_region(A_sparse) + E_region = to_buffer_region(E) + B_region = to_buffer_region(B) + C_region = to_buffer_region(C) + + A_shape = retrieve_shape(A_sparse) + E_shape = retrieve_shape(E) # nolint: F841 + B_shape = retrieve_shape(B) + C_shape = retrieve_shape(C) + + A_stride = retrieve_stride(A_sparse) + B_stride = retrieve_stride(B) + + assert len(C_shape) == 2, "current only support C as a 2D tensor" + assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor" + assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" + if len(A_shape) > 2: + for i in range(len(A_shape) - 2): + assert A_shape[i] == 1, \ + "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + if len(B_shape) > 2: + for i in range(len(B_shape) - 2): + assert B_shape[i] == 1, \ + "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + + M, N = C_shape + K = 2 * (A_shape[-2] if transpose_A else A_shape[-1]) + K_B = B_shape[-1] if transpose_B else B_shape[-2] + assert prim_expr_equal( + K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}" + + stride_a = A_stride[-2] + stride_b = B_stride[-2] + + A_offset = retrieve_offset(A_sparse) + B_offset = retrieve_offset(B) + assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0" + assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0" + offset_a = A_offset[-1] + offset_b = B_offset[-1] + + A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) + E_arg = buffer_region_to_tile_region(E_region, "r", [r for r in E_shape]) + B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.gemm_sp_py"), + A_arg, + E_arg, + B_arg, + C_arg, + transpose_A, + transpose_B, + transpose_E, + M, + N, + K, + policy, + clear_accum, + stride_a, + stride_b, + offset_a, + offset_b, + k_pack, + wg_wait, + ) diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py index fbbcf1b63..b23733377 100644 --- a/tilelang/language/fill.py +++ b/tilelang/language/fill.py @@ -32,7 +32,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim extents = [tir.IntImm("int32", 1) for _ in buffer.indices] else: extents = [] - return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), to_buffer_region(buffer, access_type="w", extents=extents), value) diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index 2bfd3a0cf..db8e04aba 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -116,7 +116,7 @@ def gemm_v1( ): """GEMM v1: use op tl.gemm.""" return _gemm_impl( - "tl.gemm", + "tl.tileop.gemm", A, B, C, @@ -145,7 +145,7 @@ def gemm_v2( ): """GEMM v2: use op tl.gemm_py.""" return _gemm_impl( - "tl.gemm_py", + "tl.tileop.gemm_py", A, B, C, diff --git a/tilelang/language/loop.py 
b/tilelang/language/loop.py
index 4f8d5c307..3478b6cc1 100644
--- a/tilelang/language/loop.py
+++ b/tilelang/language/loop.py
@@ -4,8 +4,9 @@
 from tvm import tir
 from tvm.tir import IntImm
 import tvm.script.ir_builder.tir as tb_tir
-from .v2.builder import SerialForWithStep
+from .v2.builder import SerialForWithStep, UnrollForWithStep
 from tilelang import _ffi_api
+from tvm.script.ir_builder.tir import frame
 
 
 def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
@@ -97,7 +98,7 @@ def serial(start: tir.PrimExpr,
            stop: tir.PrimExpr | None = None,
            step: tir.PrimExpr | None = None,
            *,
-           annotations: dict[str, Any] | None = None):
+           annotations: dict[str, Any] | None = None) -> frame.ForFrame:
     step_is_one = False
     step_is_one |= isinstance(step, int) and step == 1
     step_is_one |= isinstance(step, IntImm) and step.value == 1
@@ -108,3 +109,72 @@ def serial(start: tir.PrimExpr,
         stop = start
         start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
     return SerialForWithStep(start, stop, step, annotations=annotations)
+
+
+def unroll(start: tir.PrimExpr,
+           stop: tir.PrimExpr | None = None,
+           step: tir.PrimExpr | None = None,
+           *,
+           explicit: bool = False,
+           unroll_factor: int | None = None,
+           annotations: dict[str, Any] | None = None) -> frame.ForFrame:
+    """The unrolled For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration.
+
+    stop : PrimExpr
+        The maximum value of iteration.
+
+    step : PrimExpr
+        The step size of the iteration.
+
+    explicit : bool
+        Whether to explicitly unroll the loop.
+
+    unroll_factor : int
+        The unroll factor of the loop.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+
+    step_is_one = False
+    step_is_one |= isinstance(step, int) and step == 1
+    step_is_one |= isinstance(step, IntImm) and step.value == 1
+    if stop is None:
+        stop = start
+        if hasattr(start, "dtype"):
+            start = IntImm(start.dtype, 0)
+        else:
+            start = 0
+
+    # Ensure annotations carries "pragma_unroll_explicit", defaulting to the `explicit` argument
+    if annotations is None:
+        annotations = {"pragma_unroll_explicit": explicit}
+    else:
+        # Add "pragma_unroll_explicit": explicit if not already present
+        annotations = dict(annotations)
+        annotations.setdefault("pragma_unroll_explicit", explicit)
+
+    if unroll_factor is not None:
+        # pragma_unroll_explicit must be False when an unroll_factor is requested
+        if annotations.get("pragma_unroll_explicit", True):
+            raise ValueError("pragma_unroll_explicit must be False when unroll_factor is not None")
+        annotations.update({"pragma_unroll_factor": unroll_factor})
+
+    if step is None or step_is_one:
+        return tb_tir.unroll(start, stop, annotations=annotations)
+    else:
+        return UnrollForWithStep(start, stop, step, annotations=annotations)
+
+
+Serial = serial
+Unroll = unroll
diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py
index 3c4d8187b..fb84b6d78 100644
--- a/tilelang/language/reduce.py
+++ b/tilelang/language/reduce.py
@@ -13,6 +13,9 @@ def _legalize_dim(buffer: tir.Buffer, dim: int):
     return dim
 
 
+_REDUCE_OP_KEY = "tl.tileop.reduce"
+
+
 def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
     """Perform a reduction operation on a buffer along a specified dimension.
@@ -50,7 +53,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int copy(buffer, red_frag_in) tir.call_intrin( "handle", - tir.op.Op.get("tl.reduce"), + tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(red_frag_in, access_type="r"), to_buffer_region(red_frag_out, access_type="w"), reduce_type, @@ -65,7 +68,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int copy(buffer, red_frag_in) tir.call_intrin( "handle", - tir.op.Op.get("tl.reduce"), + tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(red_frag_in, access_type="r"), to_buffer_region(out, access_type="w"), reduce_type, @@ -78,7 +81,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int tir.call_intrin( "handle", - tir.op.Op.get("tl.reduce"), + tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(buffer, access_type="r"), to_buffer_region(red_frag_out, access_type="w"), reduce_type, @@ -89,7 +92,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int elif is_fragment(buffer) and is_fragment(out): tir.call_intrin( "handle", - tir.op.Op.get("tl.reduce"), + tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(buffer, access_type="r"), to_buffer_region(out, access_type="w"), reduce_type, @@ -245,7 +248,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - copy(src, cumsum_smem) tir.call_intrin( "handle", - tir.op.Op.get("tl.cumsum"), + tir.op.Op.get("tl.tileop.cumsum"), to_buffer_region(cumsum_smem, access_type="r"), to_buffer_region(cumsum_smem, access_type="w"), dim, @@ -299,7 +302,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse return cumsum_fragment(src, dst, dim, reverse) return tir.call_intrin( "handle", - tir.op.Op.get("tl.cumsum"), + tir.op.Op.get("tl.tileop.cumsum"), to_buffer_region(src, access_type="r"), to_buffer_region(dst, access_type="w"), dim, @@ -309,7 +312,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse def finalize_reducer(reducer: tir.Buffer): """ - Finalize a reducer buffer by emitting the `tl.finalize_reducer` intrinsic. + Finalize a reducer buffer by emitting the `tl.tileop.finalize_reducer` intrinsic. This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer. The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR. 
@@ -322,7 +325,7 @@ def finalize_reducer(reducer: tir.Buffer): """ return tir.call_intrin( "handle", - tir.op.Op.get("tl.finalize_reducer"), + tir.op.Op.get("tl.tileop.finalize_reducer"), to_buffer_region(reducer, access_type="w"), ) diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 75fea4c09..136bc0bac 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -7,7 +7,7 @@ def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): """Create a tl.region call for a BufferLoad and extents.""" access_type = {"r": 1, "w": 2, "rw": 3}[access_type] - return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args) + return T.call_intrin("handle", op.Op.get("tl.tileop.region"), buffer, access_type, *args) def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]): diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index aea425adc..68c109133 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -112,6 +112,11 @@ class SerialForWithStep: annotations: dict[str, Any] | None = None +@dataclass +class UnrollForWithStep(SerialForWithStep): + ... + + # Python 3.9 compatibility: avoid PEP 604 unions at runtime # Use tuple for isinstance checks and typing.Union for annotations/aliases ContinueOrBreak = (ContinueFrame, BreakFrame) @@ -270,7 +275,7 @@ def eval(self, val: Any): def ctx_for(self, it): self.check_continue_break() it = unwrap_expr(it) - if isinstance(it, SerialForWithStep): + if isinstance(it, (SerialForWithStep, UnrollForWithStep)): # Validate and compute the trip count before constructing the frame if isinstance(it.step, (int, IntImm)): step_value = it.step if isinstance(it.step, int) else it.step.value @@ -285,7 +290,14 @@ def ctx_for(self, it): f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang' ) real_stop = tir.ceildiv(it.stop - it.start, it.step) - real_frame = tir.serial(real_stop, annotations=it.annotations) + if isinstance(it, UnrollForWithStep): + real_frame = tir.unroll(real_stop, annotations=it.annotations) + elif isinstance(it, SerialForWithStep): + real_frame = tir.serial(real_stop, annotations=it.annotations) + else: + raise TypeError( + f"Invalid for loop, got {it}({type(it)}), expect one of the following: " + "range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding") with self.with_frame(real_frame) as v: IRBuilder.name('_tmp', v) yield it.start + v * it.step diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py index ee513257f..777802d2c 100644 --- a/tilelang/layout/__init__.py +++ b/tilelang/layout/__init__.py @@ -13,4 +13,4 @@ make_quarter_bank_swizzled_layout, # noqa: F401 make_linear_layout, # noqa: F401 ) -from .gemm_sp import make_metadata_layout # noqa: F401 +from .gemm_sp import make_cutlass_metadata_layout # noqa: F401 diff --git a/tilelang/layout/gemm_sp.py b/tilelang/layout/gemm_sp.py index eaaa178f5..e5d190292 100644 --- a/tilelang/layout/gemm_sp.py +++ b/tilelang/layout/gemm_sp.py @@ -17,7 +17,7 @@ def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]: return res -def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int): +def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int): """Make a layout of metadata that is compatible with cutlass sm90 compression kernel. Note that layout atom is the same for smem and gmem. 
Args: @@ -30,7 +30,7 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b block_k = 128 # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2) - if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8"]: + if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8_e4m3", "float8_e5m2"]: raise NotImplementedError(f"Unsupported dtype: {mma_dtype}") if buffer.dtype not in ["uint8", "int8"]: @@ -41,7 +41,8 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b "bfloat16": 16, "float32": 32, "int8": 8, - "float8": 8, + "float8_e4m3": 8, + "float8_e5m2": 8, } # ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117 @@ -75,8 +76,8 @@ def gen_stride(shape_ik, order): shape_i, shape_k = shape_ik[:3], shape_ik[3:] stride_i, stride_k = stride_ik[:3], stride_ik[3:] elif bits_map[mma_dtype] == 8: - shape_i, shape_k = [64], [BlockK] - stride_i, stride_k = [BlockK], [1] + shape_i, shape_k = [64], [block_k // 8] + stride_i, stride_k = [block_k // 8], [1] else: raise NotImplementedError(f"Unknown mma type {mma_dtype}") @@ -103,54 +104,48 @@ def transform(i: int, k: int) -> int: return T.Layout(shape, transform) -def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str): +def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): """Make a layout of metadata that is compatible with cutlass sm8x compression kernel. Note that layout atom is the same for smem and gmem. 
- + ref: https://github.com/pytorch/pytorch/blob/d0c24b392cbb7b213d22e42c52c6c2d1ac2da1bd/torch/sparse/_semi_structured_conversions.py#L5 Args: buffer: metadata buffer shape, for sm80 it should be a 16bit type """ - # ref: https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h#L651 - # https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/matrix.h#L405 - # https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/warp/mma_sparse_tensor_op.h#L172 - if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]: raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}") - if mma_dtype in ["float8", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]: + if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8" + ] and buffer.dtype not in ["uint32", "int32"]: raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}") - kInterleaved = 2 - stride = buffer.shape[0] * kInterleaved + m, k = buffer.shape + group = 32 if buffer.dtype.bits == 16 else 16 + interweave = 4 if buffer.dtype.bits == 16 else 2 def ColumnMajorInterleaved(i: int, j: int) -> int: - column_major = j // kInterleaved - column_minor = j % kInterleaved - return column_major * stride + i * kInterleaved + column_minor + i = i // group * group + (i % 8) * interweave + (i % group) // 8 + topright = (1 - (i % 2)) & (j % 2) + bottomleft = (i % 2) & (1 - (j % 2)) + i += topright - bottomleft + j -= topright - bottomleft + offset = (j // 2) * m * 2 + i * 2 + (j % 2) + return offset // k, offset % k return T.Layout(buffer.shape, ColumnMajorInterleaved) -def make_metadata_layout(buffer: tvm.tir.Buffer, - mma_dtype: str = "float16", - backend: str = 'cutlass', - arch: str | None = None, - **extra_args): +def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, + mma_dtype: str = "float16", + arch: str | None = None, + **extra_args): if arch is None: arch = nvcc.get_target_compute_version() compute_version = nvcc.parse_compute_version(arch) if compute_version >= (9, 0): - if backend == 'cutlass': - return _make_metadata_layout_sm90_cutlass( - buffer=buffer, mma_dtype=mma_dtype, **extra_args) - else: - raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}") + return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args) elif compute_version >= (8, 0): - if backend == 'cutlass': - return _make_metadata_layout_sm8x_cutlass(buffer=buffer, mma_dtype=mma_dtype) - else: - raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}") + return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype) else: raise NotImplementedError(f"Unsupported architecture: {arch}") diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index 5af1fc2bf..4750fa7d5 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -10,6 +10,7 @@ get_tensor_supply, TensorSupplyType, torch_assert_close, + is_float8_dtype, ) from tilelang.engine.param import KernelParam from tilelang.jit.adapter import BaseKernelAdapter @@ -125,17 +126,9 @@ def assert_allclose( if lhs is not None and rhs is not None: # in case of numsplit template, the ref output may be None # which means the value is invalid, so we skip the comparison - def is_float8(tensor: torch.Tensor) -> bool: - return tensor.dtype in { - torch.float8_e5m2, - torch.float8_e5m2fnuz, - 
torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - } - torch_assert_close( - lhs if not is_float8(lhs) else lhs.to(torch.float32), - rhs if not is_float8(rhs) else rhs.to(torch.float32), + lhs if not is_float8_dtype(lhs.dtype) else lhs.to(torch.float32), + rhs if not is_float8_dtype(rhs.dtype) else rhs.to(torch.float32), rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio, diff --git a/tilelang/tileop/__init__.py b/tilelang/tileop/__init__.py index 5656494fe..a99cbd878 100644 --- a/tilelang/tileop/__init__.py +++ b/tilelang/tileop/__init__.py @@ -1 +1,2 @@ from .gemm import GemmPy # noqa: F401 +from .gemm_sp import GemmSPPy # noqa: F401 diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 4c6762450..90960904f 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -3,6 +3,7 @@ from tvm import tir from tvm.target import Target from tvm.ir.base import Node +from tvm.ir import Range from tvm.runtime import Scriptable import tvm_ffi from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy @@ -16,13 +17,14 @@ @tvm_ffi.register_global_func("tl.gemm_py.infer_layout") -def gemm_py_infer_layout(gemm_py, target, thread_bounds): +def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range): thread_nums = thread_bounds.extent return gemm_py.infer_layout(target, thread_nums) @tvm_ffi.register_global_func("tl.gemm_py.lower") -def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var): +def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range, + thread_var: tir.Var): thread_nums = thread_bounds.extent stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var) return stmt diff --git a/tilelang/tileop/gemm_sp/__init__.py b/tilelang/tileop/gemm_sp/__init__.py new file mode 100644 index 000000000..fdac694ce --- /dev/null +++ b/tilelang/tileop/gemm_sp/__init__.py @@ -0,0 +1,69 @@ +from tilelang import tvm as tvm +from tvm import tir +from tilelang.utils.target import ( + target_is_cuda,) +from tvm.target import Target +from tvm.ir.base import Node +from tvm.ir import Range +from tvm.runtime import Scriptable +import tvm_ffi +from tilelang.ir import GemmWarpPolicy +from .gemm_sp_mma import GemmSPMMA + + +@tvm_ffi.register_global_func("tl.gemm_sp_py.infer_layout") +def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range): + thread_nums = thread_bounds.extent + return gemm_sp_py.infer_layout(target, thread_nums) + + +@tvm_ffi.register_global_func("tl.gemm_sp_py.lower") +def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, + thread_var: tir.Var): + thread_nums = thread_bounds.extent + stmt = gemm_sp_py.lower(target, thread_nums, thread_var) + return stmt + + +@tvm_ffi.register_object("tl.GemmSPPy") +class GemmSPPy(Node, Scriptable): + A: tir.Buffer + E: tir.Buffer + B: tir.Buffer + C: tir.Buffer + + APtr: tir.PrimExpr + EPtr: tir.PrimExpr + BPtr: tir.PrimExpr + CPtr: tir.PrimExpr + + M: int + N: int + K: int + + trans_A: bool + trans_B: bool + + stride_A: int + stride_B: int + offset_A: int + offset_B: int + clear_accum: bool + k_pack: int + wg_wait: int + policy: GemmWarpPolicy + + def infer_layout(self, target: Target, thread_nums: int): + if target_is_cuda(target): + # TODO(lei): Support more cuda architectures, now mma only + return GemmSPMMA(self).infer_layout(target, thread_nums) + else: + raise ValueError(f"Unsupported target: {target}") + + def lower(self, target: Target, thread_nums: int, thread_var: 
tir.Var): + if target_is_cuda(target): + # TODO(lei): Support more cuda architectures, now mma only + # Now only implement ssr layout + return GemmSPMMA(self).lower(target, thread_nums, thread_var) + else: + raise ValueError(f"Unsupported target: {target}") diff --git a/tilelang/tileop/gemm_sp/gemm_sp_base.py b/tilelang/tileop/gemm_sp/gemm_sp_base.py new file mode 100644 index 000000000..51c6786b4 --- /dev/null +++ b/tilelang/tileop/gemm_sp/gemm_sp_base.py @@ -0,0 +1,131 @@ +from dataclasses import dataclass +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang.utils.language import is_shared, is_fragment +from tilelang.ir import GemmWarpPolicy +from tvm.ir.base import Node + + +@dataclass +class GemmSPBase: + gemm_sp_node: Node + + def infer_layout(self, target: Target, thread_nums: int): + raise NotImplementedError("infer_layout is not implemented") + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + raise NotImplementedError("lower is not implemented") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) + + @property + def M(self) -> int: + return self.gemm_sp_node.M + + @property + def N(self) -> int: + return self.gemm_sp_node.N + + @property + def K(self) -> int: + return self.gemm_sp_node.K + + @property + def trans_A(self) -> bool: + return self.gemm_sp_node.trans_A + + @property + def trans_B(self) -> bool: + return self.gemm_sp_node.trans_B + + @property + def trans_E(self) -> bool: + return self.gemm_sp_node.trans_E + + @property + def e_dtype(self) -> str: + return self.E.dtype + + @property + def in_dtype(self) -> str: + assert self.A.dtype == self.B.dtype, "A and B must have the same dtype" + return self.A.dtype + + @property + def accum_dtype(self) -> str: + return self.C.dtype + + @property + def A(self) -> tir.Buffer: + return self.gemm_sp_node.A + + @property + def E(self) -> tir.Buffer: + return self.gemm_sp_node.E + + @property + def B(self) -> tir.Buffer: + return self.gemm_sp_node.B + + @property + def C(self) -> tir.Buffer: + return self.gemm_sp_node.C + + @property + def ARegion(self) -> tir.PrimExpr: + return self.gemm_sp_node.ARegion + + @property + def ERegion(self) -> tir.PrimExpr: + return self.gemm_sp_node.ERegion + + @property + def BRegion(self) -> tir.PrimExpr: + return self.gemm_sp_node.BRegion + + @property + def CRegion(self) -> tir.PrimExpr: + return self.gemm_sp_node.CRegion + + @property + def stride_A(self) -> int: + return self.gemm_sp_node.stride_A + + @property + def stride_B(self) -> int: + return self.gemm_sp_node.stride_B + + @property + def offset_A(self) -> int: + return self.gemm_sp_node.offset_A + + @property + def offset_B(self) -> int: + return self.gemm_sp_node.offset_B + + @property + def clear_accum(self) -> bool: + return self.gemm_sp_node.clear_accum + + @property + def k_pack(self) -> int: + return self.gemm_sp_node.k_pack + + @property + def wg_wait(self) -> int: + return self.gemm_sp_node.wg_wait + + @property + def policy(self) -> GemmWarpPolicy: + return self.gemm_sp_node.policy diff --git a/tilelang/tileop/gemm_sp/gemm_sp_mma.py b/tilelang/tileop/gemm_sp/gemm_sp_mma.py new file mode 100644 index 000000000..50a40bb91 --- /dev/null +++ b/tilelang/tileop/gemm_sp/gemm_sp_mma.py 
@@ -0,0 +1,247 @@ +from .gemm_sp_base import GemmSPBase +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +from tilelang.utils.language import is_shared, is_fragment +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmSPMMA(GemmSPBase): + + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = SparseTensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + e_dtype=self.e_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + e_transposed=self.trans_E, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=self.K, + ) + if self.is_gemm_ss(): + return { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_sr(): + return { + self.A: make_swizzled_layout(self.A), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rr(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError( + f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = SparseTensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + e_dtype=self.e_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + e_transposed=self.trans_E, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=self.K, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols + local_size_a = mma_emitter.local_size_a + local_size_e = mma_emitter.local_size_e + local_size_b = mma_emitter.local_size_b + micro_size_k = mma_emitter.micro_size_k + A_shared = self.A + E_shared = self.E + B_shared = self.B + C_local = self.C + assert micro_size_k <= self.K, f"K dimension {self.K} should be >= micro size k {micro_size_k}" + if self.is_gemm_ss(): + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. 
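+                Here both A and B are staged in shared memory (the "ss" in ``_gemm_ssr``),
+                while the accumulator C stays in register fragments.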
+ """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (self.K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load E into fragment + mma_emitter.ldmatrix_e( + E_local, + E_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_sr(): + B_local = self.B + + @T.prim_func + def _gemm_srr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) + + for ki in T.serial(0, (self.K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load E into fragment + mma_emitter.ldmatrix_e( + E_local, + E_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + # alloc_buffers body + # insert into parent block + return _Simplify(_gemm_srr, inline_let=True) + elif self.is_gemm_rs(): + A_local = self.A + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (self.K // micro_size_k)): + # Load E into fragment + mma_emitter.ldmatrix_e( + E_local, + E_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + elif self.is_gemm_rr(): + A_local = self.A + B_local = self.B + + @T.prim_func + def _gemm_rrr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. 
+ """ + E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) + + for ki in T.serial(0, (self.K // micro_size_k)): + # Load E into fragment + mma_emitter.ldmatrix_e( + E_local, + E_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rrr, inline_let=True) + else: + raise ValueError( + f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py index cd364b8bb..a7b17ad93 100644 --- a/tilelang/utils/sparse.py +++ b/tilelang/utils/sparse.py @@ -3,6 +3,7 @@ import torch import warnings from tilelang.contrib import nvcc +from tilelang.utils.tensor import is_float8_dtype, fp8_remove_negative_zeros_ from torch.utils.cpp_extension import load, _import_module_from_library from tilelang import env @@ -15,11 +16,10 @@ def _get_cached_lib(): name = 'compress_lib' - cached_path = os.path.join(_CACHE_DIR, f"{name}.so") - if os.path.exists(cached_path): + if os.path.exists(os.path.join(_CACHE_DIR, f"{name}.so")): try: - return _import_module_from_library(name, cached_path) + return _import_module_from_library(name, _CACHE_DIR, is_python_module=True) except Exception: # If loading fails, recompile pass @@ -88,7 +88,18 @@ def compress(A: torch.Tensor, if compute_version >= (9, 0): return compress_sm90(A, transposed=transposed, **kwargs) elif compute_version >= (8, 0): - return compress_sm80(A, transposed=transposed) + if transposed: + A = A.t().contiguous() + origin_dtype = A.dtype + if is_float8_dtype(origin_dtype): + fp8_remove_negative_zeros_(A) + A = A.view(torch.int8) + A_sp, E = compress_sm80(A, transposed=False) + if is_float8_dtype(origin_dtype): + A_sp = A_sp.view(origin_dtype) + if transposed: + A_sp = A_sp.t().contiguous() + return A_sp, E else: raise ValueError(f"Unsupported CUDA compute version: {compute_version}. " "Supported versions are sm_80 and sm_90.") @@ -105,6 +116,8 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp transposed (bool): If True, returns a transposed tensor of shape (K, M) """ elem, group = 2, 4 + if dtype == torch.float32: + elem, group = 1, 2 tensor = torch.randn((M, K), dtype=torch.float, device=device).view(M, -1, group) indice = tensor.topk(elem, dim=-1).indices tensor.scatter_(-1, indice, 0) @@ -114,6 +127,36 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp return tensor.to(dtype) # dtype like float8 might not have randn kernel +def randint_semi_sparse(M: int, + K: int, + low: int, + high: int, + dtype=torch.int32, + device='cuda', + transposed: bool = False): + """ + Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension. 
+ Args: + M (int): Number of rows + K (int): Number of columns + low (int): Lower bound of the random integers + high (int): Upper bound of the random integers + dtype: Data type of the tensor + device: Device to create the tensor on + transposed (bool): If True, returns a transposed tensor of shape (K, M) + """ + elem, group = 2, 4 + if dtype == torch.float32: + elem, group = 1, 2 + tensor = torch.randint(low, high, (M, K), dtype=dtype, device=device).view(M, -1, group) + indice = tensor.topk(elem, dim=-1).indices + tensor.scatter_(-1, indice, 0) + tensor = tensor.view(M, K) + if transposed: + tensor = tensor.t().contiguous() + return tensor + + def arange_semi_sparse(M: int, K: int, dtype=torch.float16, @@ -129,6 +172,8 @@ def arange_semi_sparse(M: int, transposed (bool): If True, returns a transposed tensor of shape (K, M) """ elem, group = 2, 4 + if dtype == torch.float32: + elem, group = 1, 2 tensor = torch.arange(M * K, dtype=dtype, device=device).view(M, -1, group) indice = tensor.topk(elem, dim=-1).indices tensor.scatter_(-1, indice, 0) diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index b275708c4..b2905fb1b 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -5,6 +5,22 @@ import numpy as np +def is_float8_dtype(dtype: torch.dtype) -> bool: + return dtype in { + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + } + + +def fp8_remove_negative_zeros_(tensor: torch.Tensor): + assert is_float8_dtype(tensor.dtype), "Input tensor must be of float8 dtype" + bits = tensor.view(torch.uint8) + zeros_mask = (tensor == 0) + bits[zeros_mask] = 0x00 + + class TensorSupplyType(Enum): Integer = 1 Uniform = 2
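A brief demonstration of the new float8 utilities may clarify why the sm_80 compression path scrubs negative zeros: `-0.0` in a signed-zero float8 format encodes as the byte `0x80`, which the `int8` view used by the compressor would count as a nonzero element and thereby break 2:4 structure detection. The sketch below assumes a CUDA PyTorch build with `float8_e4m3fn` support; it is a usage illustration, not part of the diff.

```python
# Minimal sketch of fp8_remove_negative_zeros_ on a float8_e4m3fn tensor.
import torch

from tilelang.utils.tensor import fp8_remove_negative_zeros_, is_float8_dtype

x = torch.tensor([0.0, -0.0, 1.0, -2.0], device="cuda").to(torch.float8_e4m3fn)
assert is_float8_dtype(x.dtype)

print(x.view(torch.uint8))     # the -0.0 slot should show up as 128 (0x80)
fp8_remove_negative_zeros_(x)  # in place: rewrites every zero to the +0 bit pattern
print(x.view(torch.uint8))     # all zeros now share the byte 0x00
```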
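As a quick sanity check of the generator helpers touched by this PR, the sketch below verifies the 2:4 pattern produced by the new `randint_semi_sparse` and runs the existing float16 generator through `compress()`. It assumes an sm_80-class GPU (the sm_90 path may expect additional keyword arguments), and the exact shape of the metadata tensor `E` depends on the architecture, so it is only printed.

```python
# Sketch: 2:4 pattern check plus the float16 compress() round trip (sm_80-class GPU).
import torch

from tilelang.utils.sparse import compress, randint_semi_sparse, randn_semi_sparse

M, K = 128, 128

# randint_semi_sparse zeroes 2 of every 4 elements along K, so each group of 4
# keeps at most 2 nonzeros.
A_int = randint_semi_sparse(M, K, low=-4, high=4, dtype=torch.int32, device="cuda")
assert int((A_int.view(M, K // 4, 4) != 0).sum(dim=-1).max()) <= 2

# The float16 generator feeds straight into compress(), which returns the packed
# operand of shape (M, K // 2) together with the metadata tensor E.
A = randn_semi_sparse(M, K, dtype=torch.float16, device="cuda")
A_sp, E = compress(A, transposed=False)
assert A_sp.shape == (M, K // 2)
print(E.shape)
```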