diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 615f173b9..04d849115 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,7 +32,7 @@ repos:
args: [--ignore-case]
files: ^docs/spelling_wordlist\.txt$
- repo: https://github.com/pre-commit/mirrors-clang-format
- rev: v21.1.2 # sync with requirements-lint.txt
+ rev: v21.1.6 # sync with requirements-lint.txt
hooks:
- id: clang-format
exclude: |
@@ -41,7 +41,7 @@ repos:
^.+\.json$
)
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.14.3 # sync with requirements-lint.txt
+ rev: v0.14.7 # sync with requirements-lint.txt
hooks:
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix]
diff --git a/3rdparty/tvm b/3rdparty/tvm
index fc7ed0b9c..e8b02611f 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit fc7ed0b9cb7a52eb1c8bf6e8c26bbb8dff3655ce
+Subproject commit e8b02611fd6b803273c5c3e15aa3a030c32dbd30
diff --git a/README.md b/README.md
index d7cdabee5..0c0769e7d 100644
--- a/README.md
+++ b/README.md
@@ -209,7 +209,7 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
-# cuda_source = jit_kernel.get_kernel_source()
+# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel
diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py
index 4e4ed6128..0ff3cd0b6 100644
--- a/benchmark/matmul/benchmark_matmul_sp.py
+++ b/benchmark/matmul/benchmark_matmul_sp.py
@@ -9,7 +9,7 @@
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
-from tilelang.layout import make_metadata_layout
+from tilelang.layout import make_cutlass_metadata_layout
# Configure logger
logger = logging.getLogger(__name__)
@@ -86,7 +86,7 @@ def get_configs(M, N, K):
return configs
-def matmul_sp(M, N, K, accum_dtype):
+def matmul_sp(M, N, K, in_dtype, accum_dtype):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
@@ -161,14 +161,13 @@ def kernel(
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
- dtype = "float16"
e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func
def main(
- A_sparse: T.Tensor((M, K // 2), dtype),
+ A_sparse: T.Tensor((M, K // 2), in_dtype),
E: T.Tensor((M, K // e_factor), e_dtype),
- B: T.Tensor((K, N), dtype),
+ B: T.Tensor((K, N), in_dtype),
C: T.Tensor((M, N), accum_dtype),
):
"""
@@ -187,9 +186,9 @@ def main(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
- A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
+ A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
- B_shared = T.alloc_shared((block_K, block_N), dtype)
+ B_shared = T.alloc_shared((block_K, block_N), in_dtype)
# Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
# Allocate a local fragment for intermediate accumulation
@@ -204,11 +203,9 @@ def main(
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
- make_metadata_layout(
- E, mma_dtype="float16", backend="cutlass", block_k=block_K),
+ make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
E_shared:
- make_metadata_layout(
- E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K),
+ make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
})
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -220,7 +217,7 @@ def main(
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared
- T.gemm_sp(
+ T.gemm_sp_v2(
A_shared,
E_shared,
B_shared,
@@ -268,7 +265,7 @@ def main(
total_flops = 2 * M * N * K
# matmul(...) returns (best_latency, best_config, ref_latency)
- best_result = matmul_sp(M, N, K, args.accum_dtype)
+ best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
diff --git a/docs/_static/img/sparse_mma_storage_example.png b/docs/_static/img/sparse_mma_storage_example.png
new file mode 100644
index 000000000..0b1639819
Binary files /dev/null and b/docs/_static/img/sparse_mma_storage_example.png differ
diff --git a/docs/deeplearning_operators/matmul_sparse.md b/docs/deeplearning_operators/matmul_sparse.md
new file mode 100644
index 000000000..5910bd6f8
--- /dev/null
+++ b/docs/deeplearning_operators/matmul_sparse.md
@@ -0,0 +1,262 @@
+# Sparse Matrix-Matrix Multiplication with Tile Library
+
+
+
+:::{warning}
+ This document is still **experimental** and may be incomplete.
+
+ This feature is still **experimental** and needs further optimization.
+
+ Suggestions and improvements are highly encouraged—please submit a PR!
+:::
+
+:::{tip}
+It is recommended to go through `docs/deeplearning_operators/matmul.md` first.
+
+Example code can be found at `examples/gemm_sp`.
+:::
+
+## Structured sparsity in the NVIDIA Ampere architecture
+
+Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation.
+
+:::{warning}
+ This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X.
+:::
+
+```{figure} ../_static/img/sparse_mma_storage_example.png
+:align: center
+
+Figure: Sparse MMA storage example (from PTX doc)
+```
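+
+Concretely, the 2:4 pattern means that within every group of four consecutive elements along the K dimension, at most two are non-zero (1:2 means at most one non-zero per pair). A minimal PyTorch sketch, purely for illustration (`is_2_4_sparse` is not part of the library), that checks this property on a row-major tensor:
+
+```python
+import torch
+
+def is_2_4_sparse(x: torch.Tensor) -> bool:
+    """Check that every group of 4 consecutive elements along K has at most 2 non-zeros."""
+    m, k = x.shape
+    groups = x.view(m, k // 4, 4)          # (M, K/4, 4): one row per 4-element group
+    return bool(((groups != 0).sum(dim=-1) <= 2).all())
+```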
+
+## Compress a dense tensor
+
+To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata.
+
+Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).
+
+A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression.
+
+```python
+from tilelang.utils.sparse import compress
+A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
+```
+
+Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.
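+
+As a quick orientation, the sketch below shows the shapes you would expect for `float16` inputs on an sm8x GPU, where metadata is packed into `int16` words covering 16 dense elements each. The concrete sizes are illustrative assumptions, not API requirements; `randn_semi_sparse` (used in the bundled examples) generates a tensor that already satisfies the 2:4 pattern:
+
+```python
+import torch
+from tilelang.utils.sparse import compress, randn_semi_sparse
+
+M, K, block_K = 256, 256, 64                                   # illustrative sizes
+A = randn_semi_sparse(M, K, device="cuda", dtype=torch.half)   # dense tensor with a 2:4 pattern
+A_sparse, E = compress(A, transposed=False, block_k=block_K)
+assert A_sparse.shape == (M, K // 2)    # non-zero values only
+assert E.shape == (M, K // 16)          # one int16 metadata word per 16 dense elements
+```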
+
+> NOTE: When using the CUTLASS compressor, there is no naive positional correspondence between `A_sparse`/`A` and `E` (i.e., the 4-element group at `[n, k]` does not correspond to the 4-bit metadata at `[n, k]` if you view the metadata as an int4 tensor).
+> The metadata is reordered internally to optimize memory access patterns (e.g., for `ldsm` instructions and vectorized loads).
+> For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
+
+
+## `T.gemm_sp` with CUTLASS's compressor
+
+:::{warning}
+
+It is strongly recommended to use `T.gemm_sp_v2` due to its greater flexibility and faster compilation time.
+
+:::
+
+A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.
+
+Check the comments in the kernel code below for the required modifications.
+
+```python
+def matmul_sp_sm80(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ in_dtype,
+ out_dtype,
+ accum_dtype,
+ num_stages,
+ threads,
+ trans_A,
+ trans_B,
+):
+ is_8_bit = "8" in in_dtype
+ metadata_dtype = 'int32' if is_8_bit else 'int16'
+ E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes
+ A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
+ B_shape = (K, N) if not trans_B else (N, K)
+ A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
+ B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
+
+ import tilelang.language as T
+
+ @T.prim_func
+ def main(
+ A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+ E: T.Tensor((M, K // E_factor), metadata_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
+ ):
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+ A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+ B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+ E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata
+ C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
+ T.annotate_layout({ # Annotate reordered cutlass metadata layout
+ E:
+ make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+ E_shared:
+ make_cutlass_metadata_layout(
+ E_shared, mma_dtype=in_dtype, arch="8.0"),
+ })
+ T.clear(C_frag)
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+ T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
+ if trans_A:
+ T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
+ else:
+ T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
+ if trans_B:
+ T.copy(B[bx * block_N, k * block_K], B_shared)
+ else:
+ T.copy(B[k * block_K, bx * block_N], B_shared)
+ T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata
+ T.copy(C_frag, C[by * block_M, bx * block_N])
+
+ return main
+```
+
+Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`.
+
+## `T.gemm_sp_v2` with a custom compressor
+
+To migrate to `gemm_sp_v2`, simply replace each call to `gemm_sp` with `gemm_sp_v2` (see the sketch below).
+
+Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors.
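+
+A minimal sketch of the change inside the kernel body, reusing the operand names from the sm80 example above; the arguments stay the same, only the call differs:
+
+```python
+# Before: lowers to a CUTLASS template and requires the CUTLASS metadata layout
+T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
+
+# After: lowers directly to PTX; works with or without an annotated metadata layout
+T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
+```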
+
+The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs.
+
+Suppose we have the following row vector:
+```python
+t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
+```
+
+The non-zero elements and their corresponding indices are:
+
+```python
+t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
+indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten()
+```
+
+The corresponding `int16` metadata is:
+```python
+# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])
+# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16)
+# Note: the binary literal above is not runnable as-is, since Python does not
+# interpret it as a two's-complement int16 value
+metadata_int16 = tensor(-29107)
+```
+
+You can decode an int16 metadata tensor using the following utility:
+```python
+def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
+ assert meta.dtype is torch.int16
+ groups_per_meta = 16 // 4
+ out = []
+ for g in range(groups_per_meta):
+ group_bits = (meta >> (g * 4)) & 0xF
+ idx0 = group_bits & 0x3
+ idx1 = (group_bits >> 2) & 0x3
+ out.append(torch.stack([idx0, idx1], dim=-1))
+ return torch.concat(out, dim=-1).view(meta.shape[0], -1)
+```
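+
+For instance, decoding the single `int16` word from the example above recovers the two indices per 4-element group, matching `indices` earlier:
+
+```python
+meta = torch.tensor([[-29107]], dtype=torch.int16)
+print(decode_metadata(meta))
+# tensor([[1, 3, 0, 1, 2, 3, 0, 2]], dtype=torch.int16)
+```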
+
+The compressor can be implemented either at the `PyTorch`/`NumPy` level or at the kernel level.
+
+For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
+
+If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
+
+```python
+
+@tilelang.jit(out_idx=[1, 2], pass_configs={
+ tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
+})
+def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
+ e_factor, e_dtype = ARCH_INFO["8.0"]
+ e_K = K // e_factor
+ elem, group = 2, 4
+
+ assert M % block_M == 0, "M must be divisible by block_M"
+ assert K % block_K == 0, "K must be divisible by block_K"
+ assert K % e_factor == 0, "K must be divisible by e_factor"
+ assert block_K % e_factor == 0, "block_K must be divisible by e_factor"
+
+ @T.prim_func
+ def kernel(
+ A: T.Tensor((M, K), dtype),
+ A_sp: T.Tensor((M, K // 2), dtype),
+ E: T.Tensor((M, e_K), e_dtype),
+ ):
+ with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
+ A_shared = T.alloc_shared((block_M, block_K), dtype)
+ A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
+ E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
+            if use_cutlass_layout:  # NOTE: make sure the compressor's metadata layout
+                T.annotate_layout({  # matches the one in your computation kernel
+ E:
+ make_cutlass_metadata_layout(
+ E, mma_dtype="float16", arch="8.0", block_k=block_K),
+ E_shared:
+ make_cutlass_metadata_layout(
+ E_shared,
+ mma_dtype="float16",
+ arch="8.0",
+ block_k=block_K),
+ })
+ T.clear(A_sp_shared)
+ T.clear(E_shared)
+ non_zero_cnt = T.alloc_local((1, ), dtype="uint8")
+ non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8")
+ T.copy(A[bx * block_M, by * block_K], A_shared)
+ for tm in T.Parallel(block_M):
+ for g_i in range(0, block_K // group):
+ a_k = g_i * group
+ T.clear(non_zero_cnt)
+ T.clear(non_zero_elt_log_idx)
+ for i in range(group):
+ val = A_shared[tm, a_k + i]
+ if val != 0.0:
+ non_zero_elt_log_idx[non_zero_cnt[0]] = i
+ A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
+ non_zero_cnt[0] += 1
+ if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
+ non_zero_elt_log_idx[0] = 0
+ non_zero_elt_log_idx[1] = 3
+ A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
+ A_sp_shared[tm, a_k // 2] = 0.0
+ elif non_zero_cnt[0] == 1:
+ A_sp_shared[tm, a_k // 2 + 1] = 0
+ non_zero_elt_log_idx[1] = 3
+ for i in T.serial(elem):
+ val = non_zero_elt_log_idx[i]
+ E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
+ T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
+ T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
+
+ return kernel
+```
+
+## A note on `gemm_sp` and `gemm_sp_v2`
+
+Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout.
+
+However, fixing a specific layout introduces several potential issues:
+
+1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling.
+
+2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically.
+
+3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.)
+
+`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout.
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 9f7947766..45e7f5ea4 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -33,6 +33,7 @@ tutorials/auto_tuning
deeplearning_operators/elementwise
deeplearning_operators/gemv
deeplearning_operators/matmul
+deeplearning_operators/matmul_sparse
deeplearning_operators/deepseek_mla
:::
diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
index ed44aab69..0e2c437e3 100644
--- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
+++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
@@ -51,7 +51,12 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
- is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
+ is_float8 = in_dtype in [
+ "float8_e4m3",
+ "float8_e5m2",
+ "float8_e4m3fn",
+ "float8_e5m2fnuz",
+ ]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py
new file mode 100644
index 000000000..5125aed07
--- /dev/null
+++ b/examples/gemm_sp/example_custom_compress.py
@@ -0,0 +1,364 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+import argparse
+
+import tilelang
+import tilelang.language as T
+
+from tilelang.layout import make_cutlass_metadata_layout
+from tilelang.utils.sparse import randn_semi_sparse
+from tilelang.utils.tensor import torch_assert_close
+
+from triton.testing import do_bench
+
+import torch
+
+torch.manual_seed(42)
+
+DEFAULT_CONFIG = { # take best config from autotune script
+ "4090": {
+ 'float': {
+ 'block_M': 128,
+ 'block_N': 64,
+ 'block_K': 64,
+ 'num_stages': 1,
+ 'thread_num': 128,
+ 'policy': T.GemmWarpPolicy.Square,
+ 'enable_rasterization': True
+ },
+ 'float16': {
+ 'block_M': 256,
+ 'block_N': 128,
+ 'block_K': 64,
+ 'num_stages': 2,
+ 'thread_num': 128,
+ 'policy': T.GemmWarpPolicy.Square,
+ 'enable_rasterization': True
+ }
+ },
+ "h20": {
+ 'float': {
+ 'block_M': 128,
+ 'block_N': 64,
+ 'block_K': 128,
+ 'num_stages': 3,
+ 'thread_num': 128,
+ 'policy': T.GemmWarpPolicy.Square,
+ 'enable_rasterization': True
+ },
+ 'float16': {
+ 'block_M': 128,
+ 'block_N': 64,
+ 'block_K': 128,
+ 'num_stages': 3,
+ 'thread_num': 128,
+ 'policy': T.GemmWarpPolicy.Square,
+ 'enable_rasterization': True
+ }
+ }
+}
+
+ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
+
+
+@tilelang.jit(out_idx=[-1])
+def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
+ thread_num, policy, enable_rasterization, use_cutlass_layout):
+ e_factor, e_dtype = (16, "int16")
+
+ @T.prim_func
+ def gemm_sp_fp16_custom_compress(
+ A_sparse: T.Tensor((M, K // 2), 'float16'),
+ E: T.Tensor((M, K // e_factor), e_dtype),
+ B: T.Tensor((K, N), 'float16'),
+ C: T.Tensor((M, N), accum_dtype),
+ ):
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
+ A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
+ E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
+ B_shared = T.alloc_shared((block_K, block_N), 'float16')
+ C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+ if use_cutlass_layout:
+ T.annotate_layout({
+ E:
+ make_cutlass_metadata_layout(
+ E, mma_dtype="float16", arch="8.0", block_k=block_K),
+ E_shared:
+ make_cutlass_metadata_layout(
+ E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
+ })
+ T.clear(C_local)
+ T.disable_warp_group_reg_alloc()
+ T.use_swizzle(panel_size=10, enable=enable_rasterization)
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+ T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
+ T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
+ T.copy(B[k * block_K, bx * block_N], B_shared)
+ T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
+
+ T.copy(C_local, C_shared)
+ T.copy(C_shared, C[by * block_M, bx * block_N])
+
+ return gemm_sp_fp16_custom_compress
+
+
+def torch_compress(dense):
+ """
+    A naive compression function where each 4-bit metadata nibble corresponds to 4 consecutive elements of the original matrix in row-major layout.
+ """
+ if dense.dim() != 2:
+ raise RuntimeError(
+ f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")
+
+ m, k = dense.shape
+
+ meta_dtype = torch.int8
+ if dense.dtype == torch.int8:
+ meta_dtype = torch.int32
+ elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
+ meta_dtype = torch.int16
+ else:
+ raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+ if quadbits_per_meta_elem not in (4, 8):
+ raise RuntimeError("Invalid number of elements per meta element calculated")
+
+ if meta_dtype == torch.int32:
+ if m % 16 != 0:
+ raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16")
+ else:
+ if m % 32 != 0:
+ raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
+ if k % (4 * quadbits_per_meta_elem) != 0:
+ raise RuntimeError(
+ f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
+ )
+
+ if dense.dtype != torch.float:
+ ksparse = 4
+ dense_4 = dense.view(-1, k // ksparse, ksparse)
+ m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
+ else:
+ ksparse = 2
+ dense_2 = dense.view(-1, k // ksparse, ksparse)
+ m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
+ meta_ncols = k // (ksparse * quadbits_per_meta_elem)
+
+ # Encoding quadruples of True/False values as follows:
+ # [True, True, False, False] -> 0b0100
+ # [True, False, True, False] -> 0b1000
+ # [False, True, True, False] -> 0b1001
+ # [True, False, False, True ] -> 0b1100
+ # [False, True, False, True ] -> 0b1101
+ # [False, False, True, True ] -> 0b1110
+ # Thus, lower two bits in the encoding are index of the True value
+ # at the lowest index in the quadruple, and the higher two bits in
+ # the encoding are index of the other True value in the quadruple.
+ # In case there are less than two True values, than False value or
+ # values at some index or indices are considered True for the
+ # encoding. In case there are more than two True values, then the
+ # excess True value(s) at some indices are considered False for
+ # the encoding. The exact encodings used for these cases are as
+ # follows:
+ # [False, False, False, False] -> 0b1110
+ # [False, False, False, True ] -> 0b1110
+ # [False, False, True, False] -> 0b1110
+ # [False, True, False, False] -> 0b1001
+ # [False, True, True, True ] -> 0b1101
+ # [True, False, False, False] -> 0b1000
+ # [True, False, True, True ] -> 0b1100
+ # [True, True, False, True ] -> 0b0100
+ # [True, True, True, False] -> 0b0100
+ # [True, True, True, True ] -> 0b0100
+ # These particular encodings are chosen, with the help of Espresso
+ # logic minimizer software, for the purpose of minimization of
+ # corresponding Boolean functions, that translate non-zero flags
+ # into encoding bits. Note also possible choices for the first
+ # and last of these encodings were limited only to (0b0100,
+ # 0b1110), in order to produce valid encodings for 1:2 sparsity
+ # case.
+
+ expr0 = m0 & m1
+ expr1 = ~m0 & m1
+ expr2 = ~m0 & ~m1
+ bit0 = expr1
+ bit1 = expr2
+ bit2 = expr0 | expr2 | m3
+ bit3 = expr1 | ~m1
+ idxs0 = bit0 | (bit1.to(torch.int64) << 1)
+ idxs1 = bit2 | (bit3.to(torch.int64) << 1)
+
+ if dense.dtype != torch.float:
+ sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
+ sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
+ sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
+ else:
+ sparse = dense_2.gather(-1,
+ idxs0.unsqueeze(-1) // 2).view(
+ m, k // 2) # type: ignore[possibly-undefined]
+
+ meta_4 = idxs0 | (idxs1 << 2)
+ meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
+
+ if quadbits_per_meta_elem == 4:
+ meta = (
+ meta_n[:, :, 0]
+ | (meta_n[:, :, 1] << 4)
+ | (meta_n[:, :, 2] << 8)
+ | (meta_n[:, :, 3] << 12))
+ elif quadbits_per_meta_elem == 8:
+ meta = (
+ meta_n[:, :, 0]
+ | (meta_n[:, :, 1] << 4)
+ | (meta_n[:, :, 2] << 8)
+ | (meta_n[:, :, 3] << 12)
+ | (meta_n[:, :, 4] << 16)
+ | (meta_n[:, :, 5] << 20)
+ | (meta_n[:, :, 6] << 24)
+ | (meta_n[:, :, 7] << 28))
+
+ return (sparse, meta)
+
+
+def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
+ assert meta.dtype is torch.int16
+ groups_per_meta = 16 // 4 # 4 groups per uint16
+ out = []
+ for g in range(groups_per_meta):
+ group_bits = (meta >> (g * 4)) & 0xF
+ idx0 = group_bits & 0x3
+ idx1 = (group_bits >> 2) & 0x3
+ out.append(torch.stack([idx0, idx1], dim=-1))
+ return torch.concat(out, dim=-1).view(meta.shape[0], -1)
+
+
+@tilelang.jit(
+ out_idx=[1, 2], pass_configs={
+ tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
+ })
+def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
+ e_factor, e_dtype = ARCH_INFO["8.0"]
+ e_K = K // e_factor
+ elem, group = 2, 4
+
+ assert M % block_M == 0, "M must be divisible by block_M"
+ assert K % block_K == 0, "K must be divisible by block_K"
+ assert K % e_factor == 0, "K must be divisible by e_factor"
+ assert block_K % e_factor == 0, "block_K must be divisible by e_factor"
+
+ @T.prim_func
+ def kernel(
+ A: T.Tensor((M, K), dtype),
+ A_sp: T.Tensor((M, K // 2), dtype),
+ E: T.Tensor((M, e_K), e_dtype),
+ ):
+ with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
+ A_shared = T.alloc_shared((block_M, block_K), dtype)
+ A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
+ E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
+ if use_cutlass_layout:
+ T.annotate_layout({
+ E:
+ make_cutlass_metadata_layout(
+ E, mma_dtype="float16", arch="8.0", block_k=block_K),
+ E_shared:
+ make_cutlass_metadata_layout(
+ E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
+ })
+ T.clear(A_sp_shared)
+ T.clear(E_shared)
+ # TODO: alloc_var seems buggy here
+ non_zero_cnt = T.alloc_local((1,), dtype="uint8")
+ non_zero_elt_log_idx = T.alloc_local((elem,), dtype="uint8")
+ T.copy(A[bx * block_M, by * block_K], A_shared)
+ for tm in T.Parallel(block_M):
+ for g_i in range(0, block_K // group):
+ a_k = g_i * group
+ non_zero_cnt[0] = 0
+ for i in range(elem):
+ non_zero_elt_log_idx[i] = 0
+ for i in range(group):
+ val = A_shared[tm, a_k + i]
+ if val != 0.0:
+ non_zero_elt_log_idx[non_zero_cnt[0]] = i
+ A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
+ non_zero_cnt[0] += 1
+ # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
+ if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
+ non_zero_elt_log_idx[0] = 0
+ non_zero_elt_log_idx[1] = 3
+ A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
+ A_sp_shared[tm, a_k // 2] = 0.0
+ elif non_zero_cnt[0] == 1:
+ A_sp_shared[tm, a_k // 2 + 1] = 0
+ non_zero_elt_log_idx[1] = 3
+ for i in T.serial(elem):
+ val = non_zero_elt_log_idx[i]
+ E_shared[tm, a_k // e_factor] |= T.shift_left(
+ val, 4 * (g_i % (e_factor // group)) + 2 * i)
+ T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
+ T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
+
+ return kernel
+
+
+def main():
+
+ parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
+ parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
+ parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
+ parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
+ parser.add_argument(
+ "--use_cutlass_layout", action='store_true', help="Use cutlass layout for E tensor")
+ parser.add_argument(
+ "--use_torch_compressor", action='store_true', help="Use torch sparse for reference")
+ parser.add_argument(
+ "--accum_dtype",
+ type=str,
+ default="float",
+ choices=["float", "float16"],
+ help="Accumulation datatype")
+ parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
+ args = parser.parse_args()
+ kernel = matmul_sp_fp16_custom_compress(
+ args.m,
+ args.n,
+ args.k,
+ args.accum_dtype,
+ **DEFAULT_CONFIG[args.cfg][args.accum_dtype],
+ use_cutlass_layout=args.use_cutlass_layout)
+
+ a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
+ b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
+
+ if args.use_torch_compressor:
+ assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
+ a_sparse, e = torch_compress(a)
+ else:
+ a_sparse, e = compress_kernel(
+ args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(
+ a)
+
+ c = kernel(a_sparse, e, b)
+
+ ref_c = a @ b
+
+    assert not c.isnan().any(), "Kernel output contains NaNs, please report an issue"
+ torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
+ print(
+ f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}"
+ )
+
+ latency = do_bench(lambda: kernel(a_sparse, e, b))
+ ref_latency = do_bench(lambda: a @ b)
+
+ total_flops = 2 * args.m * args.n * args.k
+ tflops = total_flops / latency / 1e9
+ ref_tflops = total_flops / ref_latency / 1e9
+ print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
+ print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py
index 505f2b883..91682a9e4 100644
--- a/examples/gemm_sp/example_gemm_sp.py
+++ b/examples/gemm_sp/example_gemm_sp.py
@@ -5,7 +5,7 @@
import tilelang
import tilelang.language as T
-from tilelang.layout import make_metadata_layout
+from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
from triton.testing import do_bench
@@ -14,9 +14,7 @@
arch = nvcc.get_target_compute_version()
-ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
-
-default_config = { # take best config from autotune script
+DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
@@ -59,6 +57,8 @@
}
}
+ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
+
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
@@ -84,15 +84,11 @@ def gemm_sp_fp16(
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
- make_metadata_layout(
- E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
+ make_cutlass_metadata_layout(
+ E, mma_dtype="float16", block_k=block_K, arch=arch),
E_shared:
- make_metadata_layout(
- E_shared,
- mma_dtype="float16",
- backend="cutlass",
- block_k=block_K,
- arch=arch),
+ make_cutlass_metadata_layout(
+ E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
@@ -117,10 +113,10 @@ def main():
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
- parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
+ parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
- **default_config[args.cfg][args.accum_dtype])
+ **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
@@ -128,7 +124,7 @@ def main():
a_sparse, e = compress(
a,
transposed=False,
- block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
+ block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'],
arch=arch)
c = kernel(a_sparse, e, b)
diff --git a/examples/gemm_sp/test_example_gemm_sp.py b/examples/gemm_sp/test_example_gemm_sp.py
new file mode 100644
index 000000000..fe26df144
--- /dev/null
+++ b/examples/gemm_sp/test_example_gemm_sp.py
@@ -0,0 +1,16 @@
+import tilelang.testing
+
+import example_custom_compress
+import example_gemm_sp
+
+
+def test_example_custom_compress():
+ example_custom_compress.main()
+
+
+def test_example_gemm_sp():
+ example_gemm_sp.main()
+
+
+if __name__ == "__main__":
+ tilelang.testing.main()
diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py
index 58e0114be..3772dc6bd 100644
--- a/examples/gemv/example_gemv.py
+++ b/examples/gemv/example_gemv.py
@@ -360,7 +360,7 @@ def main(do_bench: bool = True):
print("Test passed!")
- if not do_bench:
+ if do_bench:
best_result = get_autotuned_kernel(N, K)
best_config = best_result.config
kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
index 59c79c283..8707c9430 100644
--- a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
+++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
@@ -1,7 +1,7 @@
import torch
import tilelang
from tilelang.utils.sparse import compress_sm90
-from tilelang.layout import make_metadata_layout
+from tilelang.layout import make_cutlass_metadata_layout
import tilelang.testing
@@ -40,15 +40,11 @@ def main(
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
- make_metadata_layout(
- E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
+ make_cutlass_metadata_layout(
+ E, mma_dtype="float16", arch="9.0", block_k=block_K),
E_shared:
- make_metadata_layout(
- E_shared,
- mma_dtype="float16",
- arch="9.0",
- backend="cutlass",
- block_k=block_K),
+ make_cutlass_metadata_layout(
+ E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
})
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc
index 1a49b7706..4ae19bafc 100644
--- a/src/op/atomic_add.cc
+++ b/src/op/atomic_add.cc
@@ -539,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
}
-TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
+TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd)
.set_num_inputs(2)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
diff --git a/src/op/copy.cc b/src/op/copy.cc
index 1bd548bc5..7bef87d64 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -57,7 +57,7 @@ static int to_CUtensorMapDataType(DataType dtype) {
}
} else if (dtype.is_bfloat16()) {
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
- } else if (dtype.is_float8_e4m3() || dtype.is_float8_e5m2()) {
+ } else if (dtype.is_float8()) {
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if (dtype.is_int()) {
switch (dtype.bits()) {
@@ -2037,7 +2037,7 @@ Array TMAIm2ColDesc::EncodeCallArgs() const {
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
// eviction_policy
// - Marked as opaque since it has side effects (memory writes)
-TIR_REGISTER_TL_OP(Copy, copy)
+TIR_REGISTER_TL_TILE_OP(Copy, copy)
.set_num_inputs(5)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
@@ -2062,7 +2062,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T,
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
// dilation, padding, eviction_policy
// - Marked as opaque since it has side effects (memory writes)
-TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
+TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(9)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
diff --git a/src/op/fill.cc b/src/op/fill.cc
index 5a773768a..714e97ad2 100644
--- a/src/op/fill.cc
+++ b/src/op/fill.cc
@@ -209,7 +209,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
return {};
}
-TIR_REGISTER_TL_OP(Fill, fill)
+TIR_REGISTER_TL_TILE_OP(Fill, fill)
.set_num_inputs(2)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc
index effc4baf0..f542b2d91 100644
--- a/src/op/finalize_reducer.cc
+++ b/src/op/finalize_reducer.cc
@@ -159,7 +159,7 @@ TileOperator FinalizeReducerOpNode::Clone() const {
return TileOperator(node);
}
-TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
+TIR_REGISTER_TL_TILE_OP(FinalizeReducerOp, finalize_reducer)
.set_num_inputs(1)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
diff --git a/src/op/gemm.cc b/src/op/gemm.cc
index 5a98cba69..57c02b0b5 100644
--- a/src/op/gemm.cc
+++ b/src/op/gemm.cc
@@ -361,13 +361,7 @@ bool GemmNode::checkWgmma() const {
if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
- else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
+ else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
@@ -380,13 +374,7 @@ bool GemmNode::checkWgmma() const {
else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0;
- else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
+ else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
@@ -826,7 +814,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
return results;
}
-TIR_REGISTER_TL_OP(Gemm, gemm)
+TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
.set_num_inputs(5)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc
index aa6c02823..378fcc6a5 100644
--- a/src/op/gemm_py.cc
+++ b/src/op/gemm_py.cc
@@ -182,13 +182,7 @@ bool GemmPyNode::checkWgmma() const {
if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
- else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
+ else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
@@ -201,13 +195,7 @@ bool GemmPyNode::checkWgmma() const {
else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0;
- else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
- return (!transA_) && transB_ && k_ % 32 == 0;
- else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
+ else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
@@ -318,7 +306,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
return results;
}
-TIR_REGISTER_TL_OP(GemmPy, gemm_py)
+TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py)
.set_num_inputs(5)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc
index df923d0e9..4c0ae08b9 100644
--- a/src/op/gemm_sp.cc
+++ b/src/op/gemm_sp.cc
@@ -302,12 +302,25 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
return results;
}
-TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
+TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp)
.set_num_inputs(5)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
-TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); }
+TVM_REGISTER_OP("tl.GemmSPWarpPolicy")
+ .set_attr("TScriptPrinterName", "GemmSPWarpPolicy");
+TVM_FFI_STATIC_INIT_BLOCK() {
+ GemmSPNode::RegisterReflection();
+ GemmSPWarpPolicyNode::RegisterReflection();
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def(
+ "tl.GemmSPWarpPolicyComputeWarpPartition",
+ [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target,
+ bool use_wgmma, int bits) {
+ policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);
+ return;
+ });
+}
} // namespace tl
} // namespace tvm
diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h
index aae5b27bf..a634e922f 100644
--- a/src/op/gemm_sp.h
+++ b/src/op/gemm_sp.h
@@ -23,6 +23,14 @@ class GemmSPWarpPolicyNode : public GemmWarpPolicyNode {
int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
GemmWarpPolicyNode);
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef()
+ .def_ro("policy_type", &GemmSPWarpPolicyNode::policy_type)
+ .def_ro("m_warp", &GemmSPWarpPolicyNode::m_warp)
+ .def_ro("n_warp", &GemmSPWarpPolicyNode::n_warp);
+ }
};
class GemmSPWarpPolicy : public ObjectRef {
diff --git a/src/op/gemm_sp_py.cc b/src/op/gemm_sp_py.cc
new file mode 100644
index 000000000..6ad8ca9b5
--- /dev/null
+++ b/src/op/gemm_sp_py.cc
@@ -0,0 +1,289 @@
+/*!
+ * \file tl/op/gemm_sp_py.cc
+ * \brief Implementation of Sparse General Matrix Multiplication (GEMM_SP)
+ * operators
+ */
+
+#include "gemm_sp_py.h"
+#include "utils.h"
+
+#include "builtin.h"
+#include
+#include
+#include
+#include
+
+#include "../target/utils.h"
+#include "tvm/ffi/string.h"
+
+namespace tvm {
+namespace tl {
+
+using namespace tir;
+
+/**
+ * @brief Construct a GemmSPPy operator from serialized TL arguments and a buffer
+ * map.
+ *
+ * This constructor deserializes operator parameters from `args` and resolves
+ * buffer references via `vmap`, populating an internal GemmSPPyNode with:
+ * - device pointers for A, E, B, C and their corresponding Buffer objects,
+ * - transpose flags for A and B,
+ * - matrix dimensions M, N, K,
+ * - warp allocation policy and clear_accum flag,
+ * - strides and memory offsets for A and B,
+ * - optional kPack (must be 1 or 2) and optional wg_wait.
+ *
+ * The populated GemmSPPyNode is stored into the wrapper's internal `data_`.
+ *
+ * @param args Positional serialized arguments produced by the TL frontend:
+ * expected layout is:
+ * [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), trans_E (Bool),
+ * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
+ * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
+ * (optional) kPack (Int), (optional) wg_wait (Int)]
+ * @param vmap Mapping from access pointer vars to Buffer objects used to
+ * resolve the Buffer corresponding to each pointer argument.
+ *
+ * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
+ * fails with an ICHECK (runtime assertion). No other validation is
+ * performed here.
+ */
+GemmSPPy::GemmSPPy(Array args) {
+ ObjectPtr node = tvm::ffi::make_object();
+
+ node->aRegion_ = NormalizeToBufferRegion(args[0]);
+ node->eRegion_ = NormalizeToBufferRegion(args[1]);
+ node->bRegion_ = NormalizeToBufferRegion(args[2]);
+ node->cRegion_ = NormalizeToBufferRegion(args[3]);
+
+ node->A = node->aRegion_->buffer;
+ node->E = node->eRegion_->buffer;
+ node->B = node->bRegion_->buffer;
+ node->C = node->cRegion_->buffer;
+
+ node->trans_A = args[4].as().value();
+ node->trans_B = args[5].as().value();
+ node->trans_E = args[6].as().value();
+ node->M = args[7].as().value()->value;
+ node->N = args[8].as().value()->value;
+ node->K = args[9].as().value()->value;
+ node->policy = GemmWarpPolicy(args[10].as().value()->value);
+ node->clear_accum = args[11].as().value();
+ node->stride_A = args[12].as().value()->value;
+ node->stride_B = args[13].as().value()->value;
+ node->offset_A = args[14].as().value()->value;
+ node->offset_B = args[15].as().value()->value;
+ if (args.size() > 16) {
+ node->kPack = args[16].as().value()->value;
+ if (node->kPack != 1 && node->kPack != 2) {
+ ICHECK(false) << "kPack must be 1 or 2";
+ }
+ }
+ if (args.size() > 17) {
+ node->wg_wait = args[17].as().value()->value;
+ }
+ data_ = std::move(node);
+}
+
+/**
+ * @brief Create a copy of this GemmSPPyNode as a TileOperator.
+ *
+ * Constructs a new GemmSPPyNode by copying the current node state and returns
+ * it wrapped in a GemmSPPy TileOperator.
+ *
+ * @return TileOperator A GemmSPPy operator that owns a copy of this node.
+ */
+TileOperator GemmSPPyNode::Clone() const {
+ auto op = tvm::ffi::make_object(*this);
+ return GemmSPPy(op);
+}
+
+GemmInst GemmSPPyNode::GetGemmInst(int block_size, Target target) const {
+ int warp_size = TargetGetWarpSize(target);
+ int num_warps = block_size / warp_size;
+ bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
+ (num_warps % 4 == 0) && CheckWGMMA();
+ if (allow_wgmma) {
+ return GemmInst::kWGMMA;
+ } else if (TargetIsCDNA(target)) {
+ return GemmInst::kMFMA;
+ } else if (TargetIsCuda(target)) {
+ return GemmInst::kMMA;
+ } else {
+ ICHECK(0) << "Unsupported target for gemm: " << target->str();
+ }
+}
+
+/**
+ * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
+ *
+ * Evaluates device-memory placement, data-type combinations, transpose flags,
+ * and K divisibility constraints required for the Hopper WGMMA code path.
+ *
+ * The check returns true only when:
+ * - B resides in shared memory ("shared" or "shared.dyn"); and
+ * - (C, A, B) dtypes match one of the supported combinations below and K
+ * satisfies the required alignment; and
+ * - for combinations that require specific orientations, A is not transposed
+ * and B is transposed.
+ *
+ * Supported combinations and constraints:
+ * - C=float16:
+ * - A=float16, B=float16: K % 16 == 0
+ * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
+ * 32 == 0
+ * - C=float32:
+ * - A=float16, B=float16: K % 16 == 0
+ * - A=bfloat16, B=bfloat16: K % 16 == 0
+ * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
+ * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
+ * - C=int32:
+ * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
+ * and K % 32 == 0
+ *
+ * @return true if WGMMA is supported for the current buffers, dtypes, and
+ * transpose/shape constraints; false otherwise.
+ */
+bool GemmSPPyNode::CheckWGMMA() const {
+ return false; // not supported yet
+ // if (B.scope() != "shared.dyn" && B.scope() != "shared") {
+ // return false;
+ // }
+
+ // if (C->dtype == DataType::Float(16)) {
+ // if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
+ // return K % 16 == 0;
+ // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else
+ // return false;
+ // } else if (C->dtype == DataType::Float(32)) {
+ // if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
+ // return K % 16 == 0;
+ // else if (A->dtype == DataType::BFloat(16) &&
+ // B->dtype == DataType::BFloat(16))
+ // return K % 16 == 0;
+ // else if (A->dtype == DataType::Float(32) && B->dtype ==
+ // DataType::Float(32))
+ // return (!trans_A) && trans_B && K % 8 == 0;
+ // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else
+ // return false;
+ // } else if (C->dtype == DataType::Int(32)) {
+ // if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
+ // return (!trans_A) && trans_B && K % 32 == 0;
+ // else
+ // return false;
+ // } else {
+ // return false;
+ // }
+}
+
+/**
+ * @brief Parse and return the numeric GPU architecture from a Target's "arch"
+ * attribute.
+ *
+ * Examines the target's "arch" string and, if it matches the pattern
+ * "sm_", returns as an int. If the attribute is present but does not
+ * match that pattern, returns 0.
+ *
+ * Preconditions: the target must have an "arch" attribute (this is checked via
+ * ICHECK).
+ *
+ * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
+ * the arch string does not match "sm_".
+ */
+static int GetArchInt(Target target) {
+ int arch_int = 0;
+ auto s = target->GetAttr("arch");
+ ICHECK(s.has_value());
+ std::string arch = s.value();
+ if (arch.rfind("sm_", 0) == 0) {
+ arch_int = std::stoi(arch.substr(3));
+ } else {
+ arch_int = 0;
+ }
+ return arch_int;
+}
+
+Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+ auto block_size = *as_const_int(T.thread_bounds->extent);
+ GemmInst gemm_inst = GetGemmInst(block_size, T.target);
+
+ auto [warp_m, warp_n] =
+ policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst);
+
+ if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.lower")) {
+ auto prim_func =
+ Downcast((*f)(tvm::ffi::GetRef(this), T.target,
+ T.thread_bounds, T.thread_var));
+ ICHECK(prim_func->attrs.defined());
+ auto global_symbol = prim_func->attrs.GetAttr("global_symbol");
+ ICHECK(global_symbol.has_value());
+ if (prim_func->body.as()) {
+ BlockRealize block_realize = Downcast(prim_func->body);
+ auto block = block_realize->block;
+ {
+ BlockNode *n = block.CopyOnWrite();
+ n->name_hint = global_symbol.value();
+ }
+ return BlockRealize(block_realize->iter_values, block_realize->predicate,
+ block);
+ }
+    // wrap with block realize node
+ return BlockRealize(
+ /*iter_values=*/Array(),
+ /*predicate=*/const_true(),
+ /*block=*/
+ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+ /*name_hint=*/global_symbol.value(), prim_func->body));
+ } else {
+ LOG(FATAL) << "No lower function found for gemm_sp_py";
+ }
+}
+
+LayoutMap GemmSPPyNode::InferLayout(const LayoutInferArgs &T,
+ InferLevel level) const {
+ if (completed_)
+ return {};
+ LayoutMap results;
+
+ if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.infer_layout")) {
+ results = Downcast(
+ (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds));
+ } else {
+ LOG(FATAL) << "No infer layout function found for gemm_sp_py";
+ }
+
+ completed_ = true;
+ return results;
+}
+
+TIR_REGISTER_TL_TILE_OP(GemmSPPy, gemm_sp_py)
+ .set_num_inputs(5)
+ .set_attr("TCallEffectKind",
+ Integer(CallEffectKind::kOpaque));
+
+TVM_FFI_STATIC_INIT_BLOCK() { GemmSPPyNode::RegisterReflection(); }
+} // namespace tl
+} // namespace tvm
diff --git a/src/op/gemm_sp_py.h b/src/op/gemm_sp_py.h
new file mode 100644
index 000000000..2f79c5e15
--- /dev/null
+++ b/src/op/gemm_sp_py.h
@@ -0,0 +1,94 @@
+/*!
+ * \file tl/op/gemm_sp_py.h
+ * \brief Define gemm_sp_py operator.
+ *
+ */
+
+// TODO: @botbw: remove redundant code with gemm_py.h
+
+#ifndef TVM_TL_OP_GEMM_SP_PY_H_
+#define TVM_TL_OP_GEMM_SP_PY_H_
+
+#include "gemm_sp.h"
+#include "operator.h"
+
+namespace tvm {
+
+namespace tl {
+
+using namespace tir;
+
+class GemmSPPyNode : public TileOperatorNode {
+public:
+ bool CheckWGMMA() const;
+ tir::Buffer A, E, B, C;
+ // pointer to the A, E, B, C
+ BufferRegion aRegion_, eRegion_, bRegion_, cRegion_;
+ bool trans_A, trans_B, trans_E;
+ int M, N, K;
+ int stride_A, stride_B;
+ int offset_A, offset_B;
+ PrimExpr clear_accum = const_false();
+  // kPack: see bitblas/tl/mfma_macro_generator.py::k_pack
+  // only enabled for CDNA MFMA instructions
+ int kPack = 1;
+ int wg_wait = 0;
+
+  // use GemmWarpPolicy here as the atom sizes are flexible in v2
+ mutable GemmWarpPolicy policy;
+
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode,
+ TileOperatorNode);
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef()
+ .def_ro("A", &GemmSPPyNode::A)
+ .def_ro("E", &GemmSPPyNode::E)
+ .def_ro("B", &GemmSPPyNode::B)
+ .def_ro("C", &GemmSPPyNode::C)
+ .def_ro("aRegion", &GemmSPPyNode::aRegion_)
+ .def_ro("eRegion", &GemmSPPyNode::eRegion_)
+ .def_ro("bRegion", &GemmSPPyNode::bRegion_)
+ .def_ro("cRegion", &GemmSPPyNode::cRegion_)
+ .def_ro("trans_A", &GemmSPPyNode::trans_A)
+ .def_ro("trans_B", &GemmSPPyNode::trans_B)
+ .def_ro("trans_E", &GemmSPPyNode::trans_E)
+ .def_ro("M", &GemmSPPyNode::M)
+ .def_ro("N", &GemmSPPyNode::N)
+ .def_ro("K", &GemmSPPyNode::K)
+ .def_ro("stride_A", &GemmSPPyNode::stride_A)
+ .def_ro("stride_B", &GemmSPPyNode::stride_B)
+ .def_ro("offset_A", &GemmSPPyNode::offset_A)
+ .def_ro("offset_B", &GemmSPPyNode::offset_B)
+ .def_ro("clear_accum", &GemmSPPyNode::clear_accum)
+ .def_ro("kPack", &GemmSPPyNode::kPack)
+ .def_ro("wg_wait", &GemmSPPyNode::wg_wait)
+ .def_ro("policy", &GemmSPPyNode::policy);
+ }
+
+ Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
+ LayoutMap InferLayout(const LayoutInferArgs &T,
+ InferLevel level) const override;
+
+ TileOperator Clone() const;
+
+private:
+ // Target GEMM instruction
+ GemmInst GetGemmInst(int block_size, Target target) const;
+
+ mutable bool completed_ = false;
+};
+
+class GemmSPPy : public TileOperator {
+public:
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator,
+ GemmSPPyNode);
+ TVM_DLL GemmSPPy(Array args);
+ static const Op &Get();
+};
+
+} // namespace tl
+} // namespace tvm
+
+#endif // TVM_TL_OP_GEMM_SP_PY_H_
\ No newline at end of file
diff --git a/src/op/operator.h b/src/op/operator.h
index 0d9f859a7..1453f9c1e 100644
--- a/src/op/operator.h
+++ b/src/op/operator.h
@@ -77,12 +77,12 @@ TileOperator ParseOperator(Stmt stmt);
using OpBuilderFunc = ffi::TypedFunction)>;
-#define TIR_REGISTER_TL_OP(Entry, OpName) \
+#define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \
const Op &Entry::Get() { \
- static const Op &op = Op::Get("tl." #OpName); \
+ static const Op &op = Op::Get("tl.tileop." #OpName); \
return op; \
} \
- TVM_REGISTER_OP("tl." #OpName) \
+ TVM_REGISTER_OP("tl.tileop." #OpName) \
.set_attr("TScriptPrinterName", #OpName) \
.set_attr( \
"TLOpBuilder", [](Array args) { return Entry(args); })
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 0d09cc129..94572098d 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -252,17 +252,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
forward_vars.push_back(
IterVar(Range(0, s), Var(), IterVarType::kDataPar));
}
- Array forward_index;
- for (const auto &iv : forward_vars) {
- forward_index.push_back(iv->var);
- }
Var rep;
auto rep_iter =
IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar);
+ // Use default fragment indexing (single output dim) to
+ // stay consistent with other ops (e.g., ReduceOp), and
+ // bind the thread range for comparability.
const PrimExpr &forward_thread = rep;
- results.Set(buffer, Fragment(forward_vars, forward_index,
- forward_thread, rep_iter));
+ auto frag = Fragment(forward_vars, /*forward_index=*/{}, forward_thread,
+ rep_iter)
+ ->BindThreadRange(T.thread_bounds);
+ results.Set(buffer, frag);
}
}
return results;
diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index caf9198a7..40c9b83cd 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -478,7 +478,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
return {};
}
-TIR_REGISTER_TL_OP(ReduceOp, reduce)
+TIR_REGISTER_TL_TILE_OP(ReduceOp, reduce)
.set_num_inputs(4)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
@@ -563,7 +563,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
return {};
}
-TIR_REGISTER_TL_OP(CumSumOp, cumsum)
+TIR_REGISTER_TL_TILE_OP(CumSumOp, cumsum)
.set_num_inputs(4)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
diff --git a/src/op/region.cc b/src/op/region.cc
index 2a1f27456..25e78eba8 100644
--- a/src/op/region.cc
+++ b/src/op/region.cc
@@ -76,17 +76,7 @@ LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
return {};
}
-const Op &RegionOp::Get() {
- static const Op &op = Op::Get("tl.region");
- return op;
-}
-
-TVM_REGISTER_OP("tl.region")
- .set_attr("TScriptPrinterName", "region")
- .set_attr("TLOpBuilder",
- [](Array args) {
- return RegionOp(args);
- })
+TIR_REGISTER_TL_TILE_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr("TCallEffectKind",
Integer(CallEffectKind::kPure));
diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h
index 3d994bf5c..8b6ff61ba 100644
--- a/src/op/tcgen5_meta.h
+++ b/src/op/tcgen5_meta.h
@@ -52,10 +52,8 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
} else {
FAIL;
}
- } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() ||
- ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() ||
- ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() ||
- ab_dtype.is_float4_e2m1fn()) &&
+ } else if ((ab_dtype.is_float8() || ab_dtype.is_float6_e2m3fn() ||
+ ab_dtype.is_float6_e3m2fn() || ab_dtype.is_float4_e2m1fn()) &&
((c_dtype.is_float() && c_dtype.bits() == 32) ||
(c_dtype.is_float16() && c_dtype.bits() == 16))) {
if (K % 32 != 0)
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index 99512b8be..65f23d8dd 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -312,7 +312,12 @@ std::string CodeGenTileLangCUDA::Finish() {
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
- stream << "#pragma unroll\n";
+ if (unroll_factor.count(op->loop_var.get())) {
+ stream << "#pragma unroll "
+ << PrintExpr(unroll_factor[op->loop_var.get()]) << "\n";
+ } else {
+ stream << "#pragma unroll\n";
+ }
}
std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
@@ -2661,7 +2666,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body);
return;
+ } else if (op->attr_key == "pragma_unroll_factor") {
+ const IntImmNode *factor = op->value.as<IntImmNode>();
+ ICHECK(factor);
+ unroll_factor[op->node.as<VarNode>()] = Downcast<IntImm>(factor);
}
+
CodeGenC::VisitStmt_(op);
}
diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h
index 6f229f11d..11c0ad081 100644
--- a/src/target/codegen_cuda.h
+++ b/src/target/codegen_cuda.h
@@ -140,6 +140,7 @@ class CodeGenTileLangCUDA final : public CodeGenC {
std::unordered_map fragment_shapes;
std::unordered_map fragment_layouts;
+ std::unordered_map<const VarNode *, PrimExpr> unroll_factor;
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
void PrintWmmaScope(const std::string &scope, DataType t,
diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h
index b92fc73bf..bf2a5100b 100644
--- a/src/tl_templates/cuda/common.h
+++ b/src/tl_templates/cuda/common.h
@@ -127,6 +127,16 @@ TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
return result;
}
+TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0,
+ short z1, short w0, short w1) {
+ int4_t result;
+ *((short2 *)&result.x) = make_short2(x0, x1);
+ *((short2 *)&result.y) = make_short2(y0, y1);
+ *((short2 *)&result.z) = make_short2(z0, z1);
+ *((short2 *)&result.w) = make_short2(w0, w1);
+ return result;
+}
+
// Pack eight int values.
TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0,
int z1, int w0, int w1) {
diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h
index 020cb1f16..3f8ce5e6b 100644
--- a/src/tl_templates/cuda/debug.h
+++ b/src/tl_templates/cuda/debug.h
@@ -108,6 +108,16 @@ __device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
PrintTraits::print_buffer(msg, buf_name, index, var);
}
+template <>
+__device__ void debug_print_buffer_value(const char *msg,
+ const char *buf_name,
+ int index, uint16_t var) {
+ printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
+ "index=%d, dtype=uint16_t value=%u\n",
+ msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
+ threadIdx.z, buf_name, index, (uint32_t)var);
+}
+
TL_DEVICE void device_assert(bool cond) { assert(cond); }
TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
diff --git a/src/transform/atomicadd_vectorize.h b/src/transform/atomicadd_vectorize.h
index a55bc0f4a..627dc895f 100644
--- a/src/transform/atomicadd_vectorize.h
+++ b/src/transform/atomicadd_vectorize.h
@@ -11,7 +11,6 @@
#include "../op/builtin.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
-#include "atomicadd_vectorize.h"
#include "common/loop_vectorization_utils.h"
#include
#include
diff --git a/testing/python/analysis/test_tilelang_nested_loop_checker.py b/testing/python/analysis/test_tilelang_nested_loop_checker.py
index b572a707a..d3c2ec20e 100644
--- a/testing/python/analysis/test_tilelang_nested_loop_checker.py
+++ b/testing/python/analysis/test_tilelang_nested_loop_checker.py
@@ -550,5 +550,178 @@ def test_mixed_pp():
run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1])
+"""
+A tile-op inside a T.Parallel loop is likewise not permitted.
+"""
+
+
+def matmul_with_parallel(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ in_dtype,
+ out_dtype,
+ accum_dtype,
+ threads,
+ order,
+ stage,
+):
+ A_shape = (M, K)
+ B_shape = (K, N)
+ A_shared_shape = (block_M, block_K)
+ B_shared_shape = (block_K, block_N)
+
+ @T.prim_func
+ def main(
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
+ ):
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+ A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+ B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+ T.clear(C_local)
+ for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
+ for i, j in T.Parallel(block_M, block_K):
+ A_shared[i, j] = A[by * block_M + i, k * block_K + j]
+ for i, j in T.Parallel(block_K, block_N):
+ B_shared[i, j] = B[k * block_K + i, bx * block_N + j]
+
+ # T.copy(A[by * block_M, k * block_K], A_shared)
+ # T.copy(B[k * block_K, bx * block_N], B_shared)
+
+ for _ in T.Parallel(1):
+ T.gemm(A_shared, B_shared, C_local, False, False)
+ T.copy(C_local, C[by * block_M, bx * block_N])
+
+ return main
+
+
+def run_gemm_tiled_op_with_parallel(
+ order,
+ stage,
+):
+ M = 1024
+ N = 1024
+ K = 1024
+ block_M = 128
+ block_N = 128
+ block_K = 32
+ in_dtype = "float16"
+ out_dtype = "float16"
+ dtypeAccum = "float32"
+ num_threads = 128
+
+ program = matmul_nested_pipa(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ num_threads,
+ order,
+ stage,
+ )
+
+ kernel = tilelang.compile(
+ program,
+ out_idx=[2],
+ pass_configs={
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ })
+ profiler = kernel.get_profiler()
+
+ def ref_program(A, B):
+ import torch
+
+ if in_dtype == "float32":
+ # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
+ # float32 automatically; -0x1000 adjusts the raw bits toward tf32 precision.
+ A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
+ B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
+ C = torch.matmul(A.to(torch.float), B.to(torch.float))
+ C = C.to(torch.__getattribute__(out_dtype))
+ return C
+
+ profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+ program1 = matmul_with_parallel(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ num_threads,
+ order,
+ stage,
+ )
+ with pytest.raises(ValueError):
+ tilelang.compile(
+ program1,
+ out_idx=[2],
+ pass_configs={
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ })
+
+
+@tilelang.jit(out_idx=[1])
+def tir_op_with_parallel(length=256, block=16, dtype="float32"):
+
+ @T.prim_func
+ def main(
+ A: T.Tensor((length,), dtype),
+ B: T.Tensor((length,), dtype),
+ ):
+ with T.Kernel(1, threads=length) as _:
+ for i in T.Parallel(length // block):
+ for j in T.Parallel(block):
+ B[i * block + j] = T.max(A[i * block + j], 0.0)
+
+ return main
+
+
+@tilelang.jit(out_idx=[1])
+def customize_op_with_parallel(length=256, block=16, dtype="float32"):
+
+ @T.prim_func
+ def main(
+ A: T.Tensor((length,), dtype),
+ B: T.Tensor((length,), dtype),
+ ):
+ with T.Kernel(1, threads=length) as _:
+ for i in T.Parallel(length // block):
+ for j in T.Parallel(block):
+ B[i * block + j] = A[i * block + j]
+ T.atomic_add(B[i * block + j], 1.0)
+
+ return main
+
+
+def test_tiled_op_with_parallel():
+ run_gemm_tiled_op_with_parallel(order=[0, 1, 2], stage=[0, 0, 1])
+
+ kernel1 = tir_op_with_parallel(length=256, block=16)
+ data = _require_cuda_tensor((256,), torch.float32)
+ result1 = kernel1(data)
+ torch.testing.assert_close(result1, torch.relu(data), atol=1e-5, rtol=1e-5)
+ kernel2 = customize_op_with_parallel(length=256, block=16)
+ result2 = kernel2(data)
+ torch.testing.assert_close(result2, data + 1, atol=1e-5, rtol=1e-5)
+
+
if __name__ == "__main__":
tilelang.testing.main()
diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
index b4509fadc..13135d416 100644
--- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
+++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
@@ -52,7 +52,12 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
- is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
+ is_float8 = in_dtype in [
+ "float8_e4m3",
+ "float8_e5m2",
+ "float8_e4m3fn",
+ "float8_e5m2fnuz",
+ ]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
index 34def174d..46f4e123a 100644
--- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
+++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
@@ -51,7 +51,12 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
- is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
+ is_float8 = in_dtype in [
+ "float8_e4m3",
+ "float8_e5m2",
+ "float8_e4m3fn",
+ "float8_e5m2fnuz",
+ ]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
index da2e12cdc..6e20754eb 100644
--- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
+++ b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
@@ -52,7 +52,12 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
- is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
+ is_float8 = in_dtype in [
+ "float8_e4m3",
+ "float8_e5m2",
+ "float8_e4m3fn",
+ "float8_e5m2fnuz",
+ ]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
diff --git a/testing/python/language/test_tilelang_language_unroll.py b/testing/python/language/test_tilelang_language_unroll.py
new file mode 100644
index 000000000..1796302e3
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_unroll.py
@@ -0,0 +1,37 @@
+import tilelang.testing
+from tilelang import tvm as tvm
+from tilelang import language as T
+
+
+def test_unroll_with_step():
+
+ @T.prim_func
+ def main(A_ptr: T.handle):
+ A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
+
+ for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
+ for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
+ for i in T.unroll(0, 16, step=4):
+ A[0, i] = 1.0
+
+ kernel = tilelang.compile(main, target="cuda")
+ assert "#pragma unroll" in kernel.get_kernel_source()
+
+
+def test_unroll_with_unroll_factor():
+
+ @T.prim_func
+ def main(A_ptr: T.handle):
+ A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
+
+ for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
+ for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
+ for i in T.unroll(0, 16, unroll_factor=4):
+ A[0, i] = 1.0
+
+ kernel = tilelang.compile(main, target="cuda")
+ assert "#pragma unroll 4" in kernel.get_kernel_source()
+
+
+if __name__ == "__main__":
+ tilelang.testing.main()
diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
index 74b9729f6..cefe986a0 100644
--- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
+++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
@@ -2,28 +2,46 @@
import tilelang
import tilelang.testing
-from tilelang.utils.sparse import compress, randn_semi_sparse
-from tilelang.layout import make_metadata_layout
-
-torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)
-torch.manual_seed(42)
-
-STR_TO_TYPE = {
- 'float32': torch.float32,
- "float16": torch.float16,
- "bfloat16": torch.bfloat16,
- "float8_e4m3": torch.float8_e4m3fn,
- "int8": torch.int8,
- "int32": torch.int32,
-}
-
-SPARSITY_MAP = {
- # 'float32': (1, 2), # not supported for now
- torch.float16: (2, 4),
- torch.bfloat16: (2, 4),
- torch.float8_e4m3fn: (2, 4),
- torch.int8: (2, 4),
-}
+from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
+from tilelang.layout import make_cutlass_metadata_layout
+from tilelang.utils.tensor import torch_assert_close, map_torch_type
+from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
+
+torch.backends.cuda.matmul.allow_tf32 = False
+# torch.manual_seed(42) # only enable when debugging
+
+
+def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
+ is_8bit = "8" in in_dtype
+ is_unsigned = "uint" in in_dtype
+ is_int = "int" in in_dtype
+ if is_int:
+ if is_8bit:
+ low, high = (0, 4) if is_unsigned else (-2, 2)
+ else:
+ low, high = (0, 128) if is_unsigned else (-64, 64)
+ A = randint_semi_sparse(
+ M,
+ K,
+ low=low,
+ high=high,
+ dtype=map_torch_type(in_dtype),
+ device='cuda',
+ transposed=trans_A)
+ B = torch.randint(
+ size=(N, K) if trans_B else (K, N),
+ low=low,
+ high=high,
+ dtype=map_torch_type(in_dtype),
+ device='cuda')
+ else:
+ A = randn_semi_sparse(
+ M, K, dtype=torch.float32, device='cuda',
+ transposed=trans_A).to(map_torch_type(in_dtype))
+ B = torch.randn(
+ (N, K) if trans_B else (K, N), device='cuda',
+ dtype=torch.float32).to(map_torch_type(in_dtype))
+ return A, B
def matmul_sp_sm90(
@@ -60,21 +78,17 @@ def main(
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
- C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+ C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
- make_metadata_layout(
- E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
+ make_cutlass_metadata_layout(
+ E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
E_shared:
- make_metadata_layout(
- E_shared,
- mma_dtype="float16",
- arch="9.0",
- backend="cutlass",
- block_k=block_K),
+ make_cutlass_metadata_layout(
+ E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
})
T.disable_warp_group_reg_alloc()
- T.clear(C_local)
+ T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
@@ -85,8 +99,8 @@ def main(
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
- T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B)
- T.copy(C_local, C[by * block_M, bx * block_N])
+ T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
+ T.copy(C_frag, C[by * block_M, bx * block_N])
return main
@@ -107,7 +121,8 @@ def matmul_sp_sm80(
trans_B,
):
is_8_bit = "8" in in_dtype
- E_factor = 32 if is_8_bit else 16
+ metadata_dtype = 'int32' if is_8_bit else 'int16'
+ E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
@@ -118,22 +133,18 @@ def matmul_sp_sm80(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
- E: T.Tensor((M, K // E_factor), 'int32' if is_8_bit else 'int16'),
+ E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
- E_shared = T.alloc_shared((block_M, block_K // E_factor),
- 'int32' if is_8_bit else 'int16')
+ E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
- E:
- make_metadata_layout(E, mma_dtype="float16", backend="cutlass", arch="8.0"),
- E_shared:
- make_metadata_layout(
- E_shared, mma_dtype="float16", backend="cutlass", arch="8.0"),
+ E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+ E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -181,19 +192,14 @@ def run_gemm_sp(
kernel,
out_idx=[-1],
)
- A = randn_semi_sparse(M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', transposed=trans_A)
- if trans_B:
- B = torch.randn((N, K), device='cuda', dtype=torch.float32)
- else:
- B = torch.randn((K, N), device='cuda', dtype=torch.float32)
-
- if "float8" in in_dtype or "int8" in in_dtype:
- A = normalize(A.float())
- B = normalize(B.float())
-
- A = A.to(STR_TO_TYPE[in_dtype])
- B = B.to(STR_TO_TYPE[in_dtype])
-
+ A, B = generate_dense_input(
+ M=M,
+ N=N,
+ K=K,
+ trans_A=trans_A,
+ trans_B=trans_B,
+ in_dtype=in_dtype,
+ )
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
C_sp = kernel(A_sparse, E, B)
@@ -206,14 +212,22 @@ def _matmul(A, B):
if "float8" in in_dtype or "int8" in in_dtype:
A = A.to(torch.float32)
B = B.to(torch.float32)
- return torch.matmul(A, B).to(STR_TO_TYPE[out_dtype])
+ return torch.matmul(A, B)
C = _matmul(A, B)
+
if 'float8' in in_dtype:
diff = calc_diff(C_sp, C)
assert diff < 1e-3, f"{diff=}"
else:
- torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3)
+ torch_assert_close(
+ C_sp.to(torch.float32),
+ C.to(torch.float32),
+ rtol=1e-3,
+ atol=1e-3,
+ base_name="tilelang_sp",
+ ref_name="ref_dense",
+ )
print("pass")
diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
new file mode 100644
index 000000000..a82c29f38
--- /dev/null
+++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
@@ -0,0 +1,666 @@
+from tilelang import tvm as tvm
+from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
+from tilelang.utils.tensor import torch_assert_close, map_torch_type
+from tilelang.layout import make_cutlass_metadata_layout
+from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
+
+import tilelang.testing
+import torch
+
+
+def matmul(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ accum_dtype,
+ metadata_dtype,
+ E_factor,
+ num_stages,
+ threads,
+):
+ A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
+ B_shape = (N, K) if trans_B else (K, N)
+ A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
+ B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+
+ import tilelang.language as T
+
+ @T.prim_func
+ def main(
+ A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+ E: T.Tensor((M, K // E_factor), metadata_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
+ ):
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+ A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+ B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+ E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
+ C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
+ T.annotate_layout({
+ E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+ E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+ })
+ T.clear(C_frag)
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+ T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
+ if trans_A:
+ T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
+ else:
+ T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
+ if trans_B:
+ T.copy(B[bx * block_N, k * block_K], B_shared)
+ else:
+ T.copy(B[k * block_K, bx * block_N], B_shared)
+ T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
+ T.copy(C_frag, C[by * block_M, bx * block_N])
+
+ return main
+
+
+def run_gemm_ss(
+ M,
+ N,
+ K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ block_M,
+ block_N,
+ block_K,
+ num_stages=3,
+ num_threads=128,
+):
+ metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+ program = matmul(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ metadata_dtype,
+ SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor
+ num_stages,
+ num_threads,
+ )
+
+ kernel = tilelang.compile(
+ program,
+ out_idx=[3],
+ pass_configs={
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ })
+ A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
+
+ A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
+ C_sp = kernel(A_sparse, E, B)
+
+ def _matmul(A, B):
+ if trans_A:
+ A = A.T
+ if trans_B:
+ B = B.T
+ A = A.to(torch.float32)
+ B = B.to(torch.float32)
+ return torch.matmul(A, B)
+
+ C = _matmul(A, B)
+
+ torch_assert_close(
+ C_sp.to(map_torch_type(out_dtype)).to(torch.float32),
+ C.to(map_torch_type(out_dtype)).to(torch.float32),
+ rtol=1e-3,
+ atol=1e-3,
+ base_name="tilelang_sp",
+ ref_name="ref_dense",
+ )
+ print("pass")
+
+
+def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
+ is_8bit = "8" in in_dtype
+ is_unsigned = "uint" in in_dtype
+ is_int = "int" in in_dtype
+ if is_int:
+ if is_8bit:
+ low, high = (0, 4) if is_unsigned else (-2, 2)
+ else:
+ low, high = (0, 128) if is_unsigned else (-64, 64)
+ A = randint_semi_sparse(
+ M,
+ K,
+ low=low,
+ high=high,
+ dtype=map_torch_type(in_dtype),
+ device='cuda',
+ transposed=trans_A)
+ B = torch.randint(
+ size=(N, K) if trans_B else (K, N),
+ low=low,
+ high=high,
+ dtype=map_torch_type(in_dtype),
+ device='cuda')
+ else:
+ A = randn_semi_sparse(
+ M, K, dtype=map_torch_type(in_dtype), device='cuda', transposed=trans_A)
+ B = torch.randn(
+ (N, K) if trans_B else (K, N), device='cuda',
+ dtype=torch.float32).to(map_torch_type(in_dtype))
+ return A, B
+
+
+def test_gemm_ss():
+ # More test case can be found in kernel/test_tilelang_kernel_gemm.py
+ # GEMM tests for float16
+ # TODO: support transposed A compressor
+ run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2)
+ run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2)
+ run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2)
+ run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2)
+
+ # n8 test
+ run_gemm_ss(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
+
+ # int8 test
+ run_gemm_ss(128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2)
+ run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
+
+ # float8 tests
+ run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64,
+ 2)
+ run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
+
+ # tfloat32 test
+ # run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
+
+
+def matmul_rs(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ accum_dtype,
+ metadata_dtype,
+ E_factor,
+ num_stages,
+ threads,
+):
+ A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
+ B_shape = (N, K) if trans_B else (K, N)
+ A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
+ B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+ A_frag_shape = A_shared_shape
+
+ import tilelang.language as T
+
+ @T.prim_func
+ def main(
+ A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+ E: T.Tensor((M, K // E_factor), metadata_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
+ ):
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+ A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+ B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+ E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
+ A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
+ C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
+ T.annotate_layout({
+ A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+ E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+ E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+ })
+ T.clear(C_frag)
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+ T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
+ if trans_A:
+ T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
+ else:
+ T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
+ if trans_B:
+ T.copy(B[bx * block_N, k * block_K], B_shared)
+ else:
+ T.copy(B[k * block_K, bx * block_N], B_shared)
+ T.copy(A_shared, A_frag)
+ T.gemm_sp_v2(A_frag, E_shared, B_shared, C_frag, trans_A, trans_B)
+ T.copy(C_frag, C[by * block_M, bx * block_N])
+
+ return main
+
+
+def run_gemm_rs(
+ M,
+ N,
+ K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ block_M,
+ block_N,
+ block_K,
+ num_stages=3,
+ num_threads=128,
+):
+ metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+ program = matmul_rs(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ metadata_dtype,
+ SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor
+ num_stages,
+ num_threads,
+ )
+
+ kernel = tilelang.compile(
+ program,
+ out_idx=[3],
+ pass_configs={
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ })
+ A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
+ A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
+ C_sp = kernel(A_sparse, E, B)
+
+ def _matmul(A, B):
+ if trans_A:
+ A = A.T
+ if trans_B:
+ B = B.T
+ A = A.to(torch.float32)
+ B = B.to(torch.float32)
+ return torch.matmul(A, B)
+
+ C = _matmul(A, B)
+
+ torch_assert_close(
+ C_sp.to(map_torch_type(out_dtype)).to(torch.float32),
+ C.to(map_torch_type(out_dtype)).to(torch.float32),
+ rtol=1e-3,
+ atol=1e-3,
+ base_name="tilelang_sp",
+ ref_name="ref_dense",
+ )
+ print("pass")
+
+
+def test_gemm_rs():
+ # GEMM tests for float16
+ run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
+
+ # n8 tests
+ run_gemm_rs(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
+
+ # int8 tests
+ run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
+
+ # float8 tests
+ run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
+
+ # float32 tests
+ # run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
+
+
+def matmul_sr(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ accum_dtype,
+ metadata_dtype,
+ E_factor,
+ num_stages,
+ threads,
+):
+ A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
+ B_shape = (N, K) if trans_B else (K, N)
+ A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
+ B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+ B_frag_shape = B_shared_shape
+
+ import tilelang.language as T
+
+ @T.prim_func
+ def main(
+ A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+ E: T.Tensor((M, K // E_factor), metadata_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
+ ):
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+ A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+ B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+ E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
+ B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
+ C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
+ T.annotate_layout({
+ B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+ E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+ E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+ })
+ T.clear(C_frag)
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+ T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
+ if trans_A:
+ T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
+ else:
+ T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
+ if trans_B:
+ T.copy(B[bx * block_N, k * block_K], B_shared)
+ else:
+ T.copy(B[k * block_K, bx * block_N], B_shared)
+ T.copy(B_shared, B_frag)
+ T.gemm_sp_v2(A_shared, E_shared, B_frag, C_frag, trans_A, trans_B)
+ T.copy(C_frag, C[by * block_M, bx * block_N])
+
+ return main
+
+
+def run_gemm_sr(
+ M,
+ N,
+ K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ block_M,
+ block_N,
+ block_K,
+ num_stages=3,
+ num_threads=128,
+):
+ metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+ program = matmul_sr(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ metadata_dtype,
+ SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor
+ num_stages,
+ num_threads,
+ )
+
+ kernel = tilelang.compile(
+ program,
+ out_idx=[3],
+ pass_configs={
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ })
+ A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
+ A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
+ C_sp = kernel(A_sparse, E, B)
+
+ def _matmul(A, B):
+ if trans_A:
+ A = A.T
+ if trans_B:
+ B = B.T
+ A = A.to(torch.float32)
+ B = B.to(torch.float32)
+ return torch.matmul(A, B)
+
+ C = _matmul(A, B)
+
+ torch_assert_close(
+ C_sp.to(map_torch_type(out_dtype)).to(torch.float32),
+ C.to(map_torch_type(out_dtype)).to(torch.float32),
+ rtol=1e-3,
+ atol=1e-3,
+ base_name="tilelang_sp",
+ ref_name="ref_dense",
+ )
+ print("pass")
+
+
+def test_gemm_sr():
+ # GEMM tests for float16
+ run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
+
+ # n8 tests
+ run_gemm_sr(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
+
+ # int8 tests
+ run_gemm_sr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2)
+ run_gemm_sr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2)
+ run_gemm_sr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_sr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
+
+ # float8 tests
+ run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
+
+ # float32 tests
+ # run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
+
+
+def matmul_rr(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ accum_dtype,
+ metadata_dtype,
+ E_factor,
+ num_stages,
+ threads,
+):
+ A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
+ B_shape = (N, K) if trans_B else (K, N)
+ A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
+ B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+ A_frag_shape = A_shared_shape
+ B_frag_shape = B_shared_shape
+
+ import tilelang.language as T
+
+ @T.prim_func
+ def main(
+ A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+ E: T.Tensor((M, K // E_factor), metadata_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
+ ):
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+ A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+ B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+ E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
+ A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
+ B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
+ C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
+ T.annotate_layout({
+ A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+ B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+ E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+ E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+ })
+ T.clear(C_frag)
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+ T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
+ if trans_A:
+ T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
+ else:
+ T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
+ if trans_B:
+ T.copy(B[bx * block_N, k * block_K], B_shared)
+ else:
+ T.copy(B[k * block_K, bx * block_N], B_shared)
+ T.copy(A_shared, A_frag)
+ T.copy(B_shared, B_frag)
+ T.gemm_sp_v2(A_frag, E_shared, B_frag, C_frag, trans_A, trans_B)
+ T.copy(C_frag, C[by * block_M, bx * block_N])
+
+ return main
+
+
+def run_gemm_rr(
+ M,
+ N,
+ K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ block_M,
+ block_N,
+ block_K,
+ num_stages=3,
+ num_threads=128,
+):
+ metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+ program = matmul_rr(
+ M,
+ N,
+ K,
+ block_M,
+ block_N,
+ block_K,
+ trans_A,
+ trans_B,
+ in_dtype,
+ out_dtype,
+ dtypeAccum,
+ metadata_dtype,
+ SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor
+ num_stages,
+ num_threads,
+ )
+
+ kernel = tilelang.compile(
+ program,
+ out_idx=[3],
+ pass_configs={
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ })
+ A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
+ A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
+ C_sp = kernel(A_sparse, E, B)
+
+ def _matmul(A, B):
+ if trans_A:
+ A = A.T
+ if trans_B:
+ B = B.T
+ A = A.to(torch.float32)
+ B = B.to(torch.float32)
+ return torch.matmul(A, B)
+
+ C = _matmul(A, B)
+
+ torch_assert_close(
+ C_sp.to(map_torch_type(out_dtype)).to(torch.float32),
+ C.to(map_torch_type(out_dtype)).to(torch.float32),
+ rtol=1e-3,
+ atol=1e-3,
+ base_name="tilelang_sp",
+ ref_name="ref_dense",
+ )
+ print("pass")
+
+
+def test_gemm_rr():
+ # GEMM tests for float16
+ run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
+ run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
+ # n8 tests
+ run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2)
+ run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2)
+
+ # int8 tests
+ run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
+ run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
+
+ # float8 tests
+ run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
+
+ # float32 tests
+ # run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
+ # run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
+
+
+if __name__ == "__main__":
+ tilelang.testing.main()
diff --git a/tilelang/analysis/ast_printer.py b/tilelang/analysis/ast_printer.py
index c54ec5cf9..e634e0271 100644
--- a/tilelang/analysis/ast_printer.py
+++ b/tilelang/analysis/ast_printer.py
@@ -14,7 +14,7 @@ def pre_visit(statement: tir.Stmt) -> None:
Pre-order visitor to print all visited statements.
"""
- print(f"Visiting statement: {type(statement)}")
+ print(f"Visiting statement: {type(statement)}, {statement}")
def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc:
new_body = ir_transform(func.body, pre_visit, None)
diff --git a/tilelang/analysis/nested_loop_checker.py b/tilelang/analysis/nested_loop_checker.py
index 7a0d94daa..eff0fc2db 100644
--- a/tilelang/analysis/nested_loop_checker.py
+++ b/tilelang/analysis/nested_loop_checker.py
@@ -1,6 +1,7 @@
from tvm import tir
from tvm.tir import (
For,
+ Call,
PrimFunc,
PyStmtExprVisitor,
)
@@ -17,6 +18,12 @@ def is_pipelined_for(op: For) -> bool:
return any(key in op.annotations for key in anno_keys)
+def is_tile_op(op: Call) -> bool:
+ """Check if a call is a tile-op"""
+
+ return op.op.get_attr("TLOpBuilder") is not None
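+# Note: a tile-op is a call registered with the "TLOpBuilder" attribute
+# (e.g. T.gemm); plain element-wise intrinsics such as T.max do not carry
+# the attribute and are therefore still allowed inside T.Parallel.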
+
+
@tir.functor.visitor
class _NestedLoopCheckVisitor(PyStmtExprVisitor):
@@ -39,7 +46,7 @@ def visit_for_(self, op: For) -> None:
"Nested parallel loops are not allowed. "
"Please check your loop structure.")
self.in_parallel_context = True
- self.visit_stmt(child)
+ super().visit_for_(op)
self.in_parallel_context = False
return
elif is_pipelined_for(op):
@@ -48,7 +55,14 @@ def visit_for_(self, op: For) -> None:
"Pipelined loop cannot be nested inside a parallel loop. "
"Please check your loop structure.")
- self.visit_stmt(op.body)
+ super().visit_for_(op)
+
+ def visit_call_(self, op: Call) -> None:
+ if self.in_parallel_context and is_tile_op(op):
+ raise ValueError("[Tilelang Semantic Check] "
+ "Only elementwise operations are allowed inside a parallel loop. " \
+ f"Got a tile-op \"{op.op}\"."
+ )
def NestedLoopChecker():
diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py
index 47ac888cf..9b2fca2c3 100644
--- a/tilelang/autotuner/tuner.py
+++ b/tilelang/autotuner/tuner.py
@@ -325,7 +325,7 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
key = self.generate_cache_key(parameters, extra_parameters)
with self._lock:
- if env.is_cache_enabled():
+ if env.is_cache_enabled() and not env.is_autotune_cache_disabled():
# First check in-memory cache
if key in self._memory_cache:
logger.warning("Found kernel in memory cache. For better performance," \
@@ -601,7 +601,7 @@ def inner(**config_arg):
logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:
- if env.is_cache_enabled():
+ if env.is_cache_enabled() and not env.is_autotune_cache_disabled():
self._save_result_to_disk(key, autotuner_result)
self._memory_cache[key] = autotuner_result
diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py
index 0d55cbf7d..0e6a19ba1 100644
--- a/tilelang/contrib/nvcc.py
+++ b/tilelang/contrib/nvcc.py
@@ -80,8 +80,8 @@ def compile_cuda(code,
file_target = path_target if path_target else temp_target
cmd = [get_nvcc_compiler()]
cmd += [f"--{target_format}", "-O3"]
- if kernels_output_dir is not None:
- cmd += ["-lineinfo"]
+ # Always include line info for better profiling and mapping
+ cmd += ["-lineinfo"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py
index dfa8050a3..1a98c8937 100644
--- a/tilelang/engine/phase.py
+++ b/tilelang/engine/phase.py
@@ -76,10 +76,8 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:
# Debug
# tilelang.analysis.ASTPrinter()(mod)
-
# Check if there are any invalid nested loops.
tilelang.analysis.NestedLoopChecker()(mod)
-
# Check if there are any invalid symbolic T.Parallel + fragment access.
tilelang.analysis.FragmentLoopChecker()(mod)
diff --git a/tilelang/env.py b/tilelang/env.py
index 39d9e722e..b1697ef51 100644
--- a/tilelang/env.py
+++ b/tilelang/env.py
@@ -196,12 +196,6 @@ def __set__(self, instance, value):
# os.environ[self.key] = value
-# Cache control API (wrap CacheState)
-enable_cache = CacheState.enable
-disable_cache = CacheState.disable
-is_cache_enabled = CacheState.is_enabled
-
-
# Utility function for environment variables with defaults
# Assuming EnvVar and CacheState are defined elsewhere
class Environment:
@@ -234,13 +228,18 @@ class Environment:
# Kernel Build options
TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION",
"1") # print kernel name on compile
- TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # clear cache automatically if set
+ TILELANG_DISABLE_CACHE = EnvVar(
+ "TILELANG_DISABLE_CACHE",
+ "0") # disable the kernel cache (mainly for unit testing / debugging); takes precedence over CacheState
+ TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE",
+ "0") # DEPRECATED: clears the cache automatically if set
# Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0")
# Auto-tuning settings
+ TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0")
TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
"0.9") # percent of CPUs used
TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS",
@@ -267,7 +266,7 @@ def _initialize_torch_cuda_arch_flags(self) -> None:
# Cache control API (wrap CacheState)
def is_cache_enabled(self) -> bool:
- return CacheState.is_enabled()
+ return not self.is_cache_globally_disabled() and CacheState.is_enabled()
def enable_cache(self) -> None:
CacheState.enable()
@@ -275,6 +274,12 @@ def enable_cache(self) -> None:
def disable_cache(self) -> None:
CacheState.disable()
+ def is_cache_globally_disabled(self) -> bool:
+ return self.TILELANG_DISABLE_CACHE.lower() in ("1", "true", "yes", "on")
+
+ def is_autotune_cache_disabled(self) -> bool:
+ return self.TILELANG_AUTO_TUNING_DISABLE_CACHE.lower() in ("1", "true", "yes", "on")
+
def is_print_on_compilation_enabled(self) -> bool:
return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on")
@@ -290,6 +295,11 @@ def use_gemm_v1(self) -> bool:
# Instantiate as a global configuration object
env = Environment()
+# Cache control API (wraps env; caching is governed jointly by CacheState and environment variables)
+enable_cache = env.enable_cache # CacheState.enable
+disable_cache = env.disable_cache # CacheState.disable
+is_cache_enabled = env.is_cache_enabled # CacheState.is_enabled
+
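+# Illustrative usage: setting TILELANG_DISABLE_CACHE=1 (or calling disable_cache())
+# makes is_cache_enabled() return False regardless of CacheState, while
+# TILELANG_AUTO_TUNING_DISABLE_CACHE=1 additionally skips the autotuner's result cache.
+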
# Export CUDA_HOME and ROCM_HOME, both are static variables
# after initialization.
CUDA_HOME = env.CUDA_HOME
diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py
index 449b6b943..f49b59569 100644
--- a/tilelang/intrinsics/mma_layout.py
+++ b/tilelang/intrinsics/mma_layout.py
@@ -151,12 +151,43 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
return row, col
+def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id):
+ """
+ groupID = %laneid >> 2
+ threadID_in_group = %laneid % 4
+
+ row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
+ groupID + 8 Otherwise
+
+ col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4
+ (threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4
+ """
+ row = (thread_id // 4) + 8 * (local_id % 4 // 2)
+ col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4)
+ return row, col
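+# Worked example: thread_id=5, local_id=3 -> row = 5 // 4 + 8 * 1 = 9 and
+# col = (5 % 4) * 2 + 1 + 0 = 3, matching the PTX mapping in the docstring above.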
+
+
def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
row = 8 * (local_id // 8) + (thread_id // 4)
col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4)
return row, col
+def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
+ """
+ groupID = %laneid >> 2
+ threadID_in_group = %laneid % 4
+
+ row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2
+ (threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2
+
+ col = groupID
+ """
+ col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8
+ row = (thread_id // 4) + 8 * (local_id // 4)
+ return row, col
+
+
def shared_16x16_to_mma_32x8_smoothlayout(i, j):
return (i * 2 + j // 8, j % 8)
diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py
index 6e49b0582..5811eb534 100644
--- a/tilelang/intrinsics/mma_macro_generator.py
+++ b/tilelang/intrinsics/mma_macro_generator.py
@@ -22,8 +22,10 @@
shared_16x32_to_mma_32x16_layout_sr_b,
mma_load_a_32x4_to_shared_16x8_layout,
mma_load_b_32x4_to_shared_16x8_layout,
+ mma_load_b_32x8_to_shared_16x16_layout,
mma_load_a_32x16_to_shared_16x32_layout,
mma_load_b_32x16_to_shared_16x32_layout,
+ mma_load_a_32x8_to_shared_16x16_layout,
)
lift = convert
@@ -291,6 +293,8 @@ def mma_load_layout(i, j):
if not ldmatrix_available:
if DataType(a_dtype).bits == 8:
mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout
+ elif DataType(a_dtype).bits == 16:
+ mma_load_layout = mma_load_a_32x8_to_shared_16x16_layout
elif DataType(a_dtype).bits == 32:
mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout
else:
@@ -417,6 +421,8 @@ def mma_load_layout(i, j):
if not ldmatrix_available:
if DataType(b_dtype).bits == 8:
mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout
+ elif DataType(b_dtype).bits == 16:
+ mma_load_layout = mma_load_b_32x8_to_shared_16x16_layout
elif DataType(b_dtype).bits == 32:
mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout
else:
diff --git a/tilelang/intrinsics/mma_sp_layout.py b/tilelang/intrinsics/mma_sp_layout.py
new file mode 100644
index 000000000..bae86bf45
--- /dev/null
+++ b/tilelang/intrinsics/mma_sp_layout.py
@@ -0,0 +1,190 @@
+from tvm import DataType
+from typing import Literal
+
+from tilelang.intrinsics.mma_layout import (
+ mma_load_a_32x4_to_shared_16x8_layout,
+ mma_load_a_32x16_to_shared_16x32_layout,
+ mma_load_a_32x8_to_shared_16x16_layout,
+ shared_16x8_to_mma_32x4_layout_sr_a,
+ shared_16x16_to_mma_32x8_layout_sr_a,
+ shared_16x32_to_mma_32x16_layout_sr_a,
+)
+
+
+def shared_16x16_to_mma_sp_layout_sr_a(i, j):
+ return shared_16x8_to_mma_32x4_layout_sr_a(i, j)
+
+
+def shared_16x16_to_mma_sp_layout_sr_b(i, j):
+ thread_id = 4 * (i % 8) + (j % 4)
+ return thread_id, 4 * (i // 8) + (j // 4)
+
+
+def shared_16x32_to_mma_sp_layout_sr_a(i, j):
+ return shared_16x16_to_mma_32x8_layout_sr_a(i, j)
+
+
+def shared_16x32_to_mma_sp_layout_sr_b(i, j):
+ thread_id = 4 * (i % 8) + (j % 8) // 2
+ return thread_id, 8 * (i // 8) + (j // 8) * 2 + (j % 2)
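+# Worked example: (i, j) = (3, 10) -> thread_id = 4 * 3 + (10 % 8) // 2 = 13 and
+# local index = 8 * 0 + (10 // 8) * 2 + 0 = 2, i.e. element (3, 10) of the 16x32
+# tile is held by lane 13 as its local element 2.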
+
+
+def shared_16x64_to_mma_sp_layout_sr_a(i, j):
+ return shared_16x32_to_mma_32x16_layout_sr_a(i, j)
+
+
+def shared_16x64_to_mma_sp_layout_sr_b(i, j):
+ thread_id = 4 * (i % 8) + (j % 16) // 4
+ return thread_id, 16 * (i // 8) + (j // 16) * 4 + j % 4
+
+
+def mma_sp_load_a_32x4_to_shared_16x16_layout(thread_id, local_id):
+ return mma_load_a_32x4_to_shared_16x8_layout(thread_id, local_id)
+
+
+def mma_sp_load_a_32x8_to_shared_16x32_layout(thread_id, local_id):
+ return mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id)
+
+
+def mma_sp_load_a_32x16_to_shared_16x64_layout(thread_id, local_id):
+ return mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id)
+
+
+def mma_sp_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
+ col = 4 * (local_id % 4) + (thread_id % 4)
+ row = 8 * (local_id // 4) + (thread_id // 4)
+ return row, col
+
+
+def mma_sp_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
+ col = (thread_id % 4) * 2 + (local_id % 2) + ((local_id % 8) // 2) * 8
+ row = (thread_id // 4) + 8 * (local_id // 8)
+ return row, col
+
+
+def mma_sp_load_b_32x32_to_shared_16x64_layout(thread_id, local_id):
+ col = (thread_id % 4) * 4 + (local_id % 4) + 16 * ((local_id % 16) // 4)
+ row = (thread_id // 4) + 8 * (local_id // 16)
+ return row, col
+
+
+def get_logical_id_32bit(thread_id: int) -> int:
+ return (thread_id // 4) * 2 + (thread_id % 4) % 2
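+# e.g. lanes 0..7 map to logical ids 0, 1, 0, 1, 2, 3, 2, 3: within each group of
+# four lanes, the pair (t, t + 2) addresses the same metadata element.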
+
+
+def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int,
+ local_id: int) -> tuple[int, int]:
+ logical_id = get_logical_id_32bit(thread_id)
+ row = logical_id // 4 + local_id * 8
+ col = logical_id % 4
+ return row, col
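+# Worked example: thread_id=6, local_id=1 -> logical_id = 2, so the metadata
+# element maps to shared (row, col) = (2 // 4 + 1 * 8, 2 % 4) = (8, 2).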
+
+
+def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int,
+ local_id: int) -> tuple[int, int]:
+ logical_id = get_logical_id_32bit(thread_id)
+ row = logical_id // 2 + local_id * 8
+ col = logical_id % 2
+ return row, col
+
+
+def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int,
+ local_id: int) -> tuple[int, int]:
+ return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(
+ thread_id, local_id) # same mapping for 16bit and 32bit
+
+
+def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int,
+ local_id: int) -> tuple[int, int]:
+ return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(
+ thread_id, local_id) # same mapping for 16bit and 32bit
+
+
+def get_logical_id_8bit(thread_id: int) -> int:
+ return thread_id
+
+
+def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int,
+ local_id: int) -> tuple[int, int]:
+ logical_id = get_logical_id_8bit(thread_id)
+ row = logical_id // 2 + local_id * 8
+ col = (logical_id % 4) // 2 * 4 + local_id
+ return row, col
+
+
+def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int,
+ local_id: int) -> tuple[int, int]:
+ logical_id = get_logical_id_8bit(thread_id)
+ row = logical_id // 2 + local_id * 8
+ col = (logical_id % 4) // 2 * 2 + local_id
+ return row, col
+
+
+def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int,
+ local_id: int) -> tuple[int, int]:
+ # local_id is always 0
+ logical_id = get_logical_id_8bit(thread_id)
+ row = logical_id // 4 + (logical_id % 2) * 8
+ col = (logical_id % 4) // 2
+ return row, col
+
+
+def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
+ row = (local_id // 4) * 8 + thread_id % 8
+ col = (thread_id // 8) * 4 + local_id % 4
+ return row, col
+
+
+def ldmatrix_32x16_to_shared_32x16_layout(thread_id, local_id):
+ row = thread_id
+ col = local_id % 8 + 8 * (local_id // 8)
+ return row, col
+
+
+def ldmatrix_trans_32x16_to_shared_16x32_layout(thread_id, local_id):
+ row = 8 * (local_id // 8) + thread_id % 8
+ col = (thread_id // 8) * 8 + local_id % 8
+ return row, col
+
+
+def ldmatrix_trans_32x32_to_shared_shared_16x64_layout(thread_id, local_id):
+ row = (local_id // 16) * 8 + thread_id % 8
+ col = (thread_id // 8) * 16 + local_id % 16
+ return row, col
+
+
+def get_ldmatrix_offset_b(
+ matrix: Literal["B"],
+ row_idx,
+ col_idx,
+ stride,
+ dtype: Literal["float16", "int8"] = "float16",
+ transposed: bool = False,
+):
+ assert matrix == "B", "matrix should be B"
+ dtype_bits = DataType(dtype).bits
+ if dtype_bits == 32:
+ if transposed:
+ transform_func = ldmatrix_trans_32x8_to_shared_16x16_layout
+ new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
+ return new_row_idx * stride + new_col_idx
+ else:
+ raise ValueError("ldmatrix only supports B transposed for 32-bit dtype")
+ elif dtype_bits == 16:
+ transform_func = ldmatrix_32x16_to_shared_32x16_layout
+ transform_func_trans = ldmatrix_trans_32x16_to_shared_16x32_layout
+ if transposed:
+ new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
+ return new_row_idx * stride + new_col_idx
+ else:
+ new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
+ return new_row_idx * stride + new_col_idx
+ elif dtype_bits == 8:
+ if transposed:
+ transform_func = ldmatrix_trans_32x32_to_shared_shared_16x64_layout
+ new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
+ return new_row_idx * stride + new_col_idx
+ else:
+ raise ValueError("ldmatrix only supports B transposed for 8-bit dtype")
+ else:
+ raise ValueError(f"Unsupported dtype {dtype}")
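+
+
+# Worked example (illustrative): get_ldmatrix_offset_b("B", 9, 3, 32, "float16", True)
+# uses ldmatrix_trans_32x16_to_shared_16x32_layout(9, 3) = (1, 11) and returns the
+# linear shared-memory offset 1 * 32 + 11 = 43.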
diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/intrinsics/mma_sp_macro_generator.py
new file mode 100644
index 000000000..629d95d99
--- /dev/null
+++ b/tilelang/intrinsics/mma_sp_macro_generator.py
@@ -0,0 +1,864 @@
+from __future__ import annotations
+
+import tilelang.language as T
+from typing import Literal, Callable
+from tvm import DataType
+from tvm.tir import PrimExpr, IndexMap, Buffer, Var
+from tvm.runtime import convert
+from .utils import (
+ mma_store_index_map,
+ get_ldmatrix_offset,
+)
+from tilelang.utils import is_fragment
+
+from tilelang.intrinsics.mma_sp_layout import (
+ shared_16x16_to_mma_sp_layout_sr_a,
+ shared_16x16_to_mma_sp_layout_sr_b,
+ shared_16x32_to_mma_sp_layout_sr_a,
+ shared_16x32_to_mma_sp_layout_sr_b,
+ shared_16x64_to_mma_sp_layout_sr_a,
+ shared_16x64_to_mma_sp_layout_sr_b,
+ mma_sp_load_a_32x4_to_shared_16x16_layout,
+ mma_sp_load_a_32x8_to_shared_16x32_layout,
+ mma_sp_load_a_32x16_to_shared_16x64_layout,
+ mma_sp_load_b_32x8_to_shared_16x16_layout,
+ mma_sp_load_b_32x16_to_shared_16x32_layout,
+ mma_sp_load_b_32x32_to_shared_16x64_layout,
+ metadata_8bit_load_32x4_to_shared_16x4_layout_32bit,
+ metadata_16bit_load_32x2_to_shared_16x2_layout_32bit,
+ metadata_8bit_load_32x4_to_shared_16x4_layout_16bit,
+ metadata_16bit_load_32x2_to_shared_16x2_layout_16bit,
+ metadata_8bit_load_32x4_to_shared_16x4_layout_8bit,
+ metadata_16bit_load_32x2_to_shared_16x4_layout_8bit,
+ metadata_32bit_load_32x1_to_shared_16x2_layout_8bit,
+ get_ldmatrix_offset_b,
+)
+
+lift = convert
+
+
+class SparseTensorCoreIntrinEmitter:
+ """
+ To eliminate Python syntax within TIR Macro.
+ """
+
+ M_DIM = 16
+    SPARSE_FACTOR = 2 # 1:2 for tfloat32, 2:4 for 16-bit and 8-bit datatypes
+ SPARSE_SELECTOR = 0 # always use lower threads to provide metadata
+ # use lowercase as n_dim can be dynamic
+ # the smallest instructions can be m16n8k16, so the n_dim can also be 8
+ n_dim = 16
+ WARP_SIZE = 32
+ dtype_abbrv = {
+ "float16": "fp16",
+ "bfloat16": "bf16",
+ "float32": "fp32",
+ "int8": "int8",
+ "int32": "int32",
+ "float8_e4m3": "e4m3",
+ "float8_e5m2": "e5m2",
+ }
+
+ E_FACTOR_MAP = { # e_kdim = mma_kdim // e_factor
+ "float": {
+ "int16": 8,
+ "uint16": 8,
+ },
+ "float32": {
+ "int16": 8,
+ "uint16": 8,
+ },
+ "float16": {
+ "int8": 8,
+ "uint8": 8,
+ "int16": 16,
+ "uint16": 16,
+ "int32": 32,
+ "uint32": 32,
+ },
+ "bfloat16": {
+ "int8": 8,
+ "uint8": 8,
+ "int16": 16,
+ "uint16": 16,
+ "int32": 32,
+ "uint32": 32,
+ },
+ "int8": {
+ "int8": 8,
+ "uint8": 8,
+ "int16": 16,
+ "uint16": 16,
+ "int32": 32,
+ "uint32": 32,
+ },
+ "uint8": {
+ "int8": 8,
+ "uint8": 8,
+ "int16": 16,
+ "uint16": 16,
+ "int32": 32,
+ "uint32": 32,
+ },
+ "float8_e4m3": {
+ "int8": 8,
+ "uint8": 8,
+ "int16": 16,
+ "uint16": 16,
+ "int32": 32,
+ "uint32": 32,
+ },
+ "float8_e5m2": {
+ "int8": 8,
+ "uint8": 8,
+ "int16": 16,
+ "uint16": 16,
+ "int32": 32,
+ "uint32": 32,
+ },
+ }
+
+    E_REPLICATE_FACTOR = { # metadata is replicated within every 4 consecutive threads
+        "float32": 2,
+        "float16": 2, # 2 of every 4 consecutive threads provide metadata
+        "bfloat16": 2,
+        "int8": 1, # all 4 consecutive threads provide metadata
+        "uint8": 1,
+        "float8_e4m3": 1,
+        "float8_e5m2": 1,
+    }
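+
+    # Illustrative sizing: with float16 A and uint16 metadata, e_factor = 16 and the logical
+    # k_dim = 32, so each 16x32 MMA tile pairs with a 16x(32/16) = 16x2 metadata tile, and
+    # 2 of every 4 consecutive lanes supply the metadata words.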
+
+ # Represent the thread binding in the form of (tx, warp_n, warp_m)
+ is_m_first = False
+
+ def __init__(
+ self,
+ a_dtype: str = "float16",
+ e_dtype: str = "uint8",
+ b_dtype: str = "float16",
+ accum_dtype: str = "float16",
+ a_transposed: bool = False,
+ b_transposed: bool = False,
+ e_transposed: bool = False,
+ block_row_warps: int = 2,
+ block_col_warps: int = 2,
+ warp_row_tiles: int = 8,
+ warp_col_tiles: int = 8,
+ warp_k: int = 16,
+ reduce_k: int = 1,
+ num_elems_per_byte: int = 1,
+ is_m_first: bool = False,
+ thread_var: Var | None = None,
+ ):
+ self.a_dtype = a_dtype
+ self.e_dtype = e_dtype
+ self.b_dtype = b_dtype
+ self.accum_dtype = accum_dtype
+ self.a_transposed = a_transposed
+ self.b_transposed = b_transposed
+ self.e_transposed = e_transposed
+ # Hint Information
+ self.block_row_warps = block_row_warps
+ self.block_col_warps = block_col_warps
+ self.warp_row_tiles = warp_row_tiles
+ self.warp_col_tiles = warp_col_tiles
+ self.warp_k = warp_k
+ self.e_factor = self.E_FACTOR_MAP[self.a_dtype][self.e_dtype]
+ self._initialize_k_dim(a_dtype)
+ self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
+ self._initialize_micro_size(self.M_DIM, self.k_dim)
+ self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE)
+ self._initialize_mma_sp_prefix(self.k_dim)
+ self._initialize_is_m_first(is_m_first)
+
+ self.reduce_k = reduce_k
+ self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
+ self.num_elems_per_byte = num_elems_per_byte
+ self.thread_var = thread_var
+
+ if self.warp_rows == 0 or self.warp_cols == 0:
+ raise ValueError(
+ f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
+ )
+
+ def _initialize_k_dim(self, a_dtype="float16"):
+ if isinstance(a_dtype, str):
+ a_dtype = DataType(a_dtype)
+ # NOTE: k_dim here represents the logical shape of the MMA operation.
+ # When referring to the physical data movement, it should be divided by sparse_factor.
+ self.k_dim = 256 // a_dtype.bits * self.SPARSE_FACTOR
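+        # e.g. tf32: 256 // 32 * 2 = 16, float16: 256 // 16 * 2 = 32, int8/fp8: 256 // 8 * 2 = 64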
+
+ def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
+ self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR
+ self.local_size_e = (
+ m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
+ self.local_size_b = (n_dim * k_dim) // warp_size
+ self.local_size_out = (m_dim * n_dim) // warp_size
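+        # e.g. for float16 (m=16, n=16, k=32, e_factor=16): local_size_a = 8, local_size_e = 2,
+        # local_size_b = 16, local_size_out = 8 elements per thread.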
+
+ def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
+ self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
+ self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
+ self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
+
+ def _initialize_mma_sp_prefix(self, k_dim: int = 16):
+ if k_dim == 16:
+ # typically used for tfloat32
+ self.mma_prefix = "m16n8k16"
+ elif k_dim == 32:
+ # typically used for float16/bfloat16
+ self.mma_prefix = "m16n8k32"
+ elif k_dim == 64:
+ # typically used for int8/fp8
+ self.mma_prefix = "m16n8k64"
+ else:
+ raise ValueError("Unsupported k_dim")
+
+ def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
+ warp_row_tiles = self.warp_row_tiles
+ warp_col_tiles = self.warp_col_tiles
+        assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
+        assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
+        assert warp_col_tiles >= 8, f"warp_col_tiles must be at least 8, got {warp_col_tiles}"
+        assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
+
+ self.warp_rows = warp_row_tiles // m_dim
+
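+        # e.g. warp_col_tiles = 32 -> n_dim = 16, warp_cols = 2; warp_col_tiles = 24 -> n_dim = 8, warp_cols = 3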
+ if warp_col_tiles % 16 == 0:
+ self.n_dim = 16
+ self.micro_size_y = 16
+ self.warp_cols = warp_col_tiles // 16
+ else:
+ # must be divisible by 8
+ self.n_dim = 8
+ self.micro_size_y = 8
+ self.warp_cols = warp_col_tiles // 8
+
+ self.micro_size_x = m_dim
+ # NOTE: k_dim here represents the logical shape of the MMA operation.
+ self.micro_size_k = k_dim
+
+ def _initialize_is_m_first(self, is_m_first: bool | None = False):
+ if is_m_first is not None:
+ self.is_m_first = is_m_first
+
+ def get_thread_binding(self):
+ if self.thread_var is None:
+ current_frame = T.KernelLaunchFrame.Current()
+ assert current_frame is not None, "Must be called in a T.Kernel Frame"
+ return current_frame.get_thread_binding()
+ else:
+ return self.thread_var
+
+ def get_store_index_map(self, inverse: bool = False) -> IndexMap:
+ warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
+ index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
+ if not inverse:
+ return index_map
+ inverse_index_map = index_map.inverse([warp_size, local_size_c])
+ return inverse_index_map
+
+ def extract_thread_binding(
+ self,
+ thread_id: PrimExpr,
+ is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
+ """
+        is_m_first: True if the thread binding is laid out as (tx, warp_n, warp_m), i.e.
+            [warp_size, block_col_warps (split N), block_row_warps (split M)] from the innermost axis outward.
+            Otherwise it is laid out as (tx, warp_m, warp_n), i.e.
+            [warp_size, block_row_warps (split M), block_col_warps (split N)].
+ """
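+        # Worked example (illustrative): with block_row_warps = block_col_warps = 2 and
+        # is_m_first = False, thread_id = 96 maps to lane_id = 0, warp_m = 1, warp_n = 1.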
+ WARP_SIZE = self.WARP_SIZE
+ block_row_warps = self.block_row_warps
+ block_col_warps = self.block_col_warps
+
+ # if is_m_first is None, then use the default value
+ if is_m_first is None:
+ is_m_first = self.is_m_first
+
+ if is_m_first:
+ lane_id, warp_n, warp_m = (
+ thread_id % WARP_SIZE,
+ (thread_id // WARP_SIZE) % block_col_warps,
+ (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
+ )
+ return lane_id, warp_n, warp_m
+ else:
+ lane_id, warp_m, warp_n = (
+ thread_id % WARP_SIZE,
+ (thread_id // WARP_SIZE) % block_row_warps,
+ (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
+ )
+ return lane_id, warp_n, warp_m
+
+ def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0):
+ warp_row_tiles = self.warp_row_tiles
+ warp_rows = self.warp_rows
+ warp_k = self.warp_k
+ micro_size_x = self.micro_size_x
+ micro_size_k = self.micro_size_k
+ local_size_a = self.local_size_a
+ a_dtype = self.a_dtype
+ a_transposed = self.a_transposed
+        # ldmatrix cannot be used when A is transposed and its dtype is not 16-bit (e.g. int8/fp8/tf32).
+ ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed)
+
+ def mma_load_layout(i, j):
+ return i, j
+
+ if not ldmatrix_available:
+ if DataType(a_dtype).bits == 8:
+ mma_load_layout = mma_sp_load_a_32x16_to_shared_16x64_layout
+ elif DataType(a_dtype).bits == 16:
+ mma_load_layout = mma_sp_load_a_32x8_to_shared_16x32_layout
+ elif DataType(a_dtype).bits == 32:
+ mma_load_layout = mma_sp_load_a_32x4_to_shared_16x16_layout
+ else:
+ raise ValueError(f"Unsupported dtype: {a_dtype}")
+
+ thread_binding = self.get_thread_binding()
+
+ @T.macro
+ def _warp_ldmatrix_a(
+ A_local_buf,
+ A_shared_buf,
+ ki,
+ thread_binding,
+ rk=0,
+ ):
+ stride = A_shared_buf.shape[-1]
+ tx, _, warp_m = self.extract_thread_binding(thread_binding)
+ trans = self.a_transposed
+
+ for i in T.serial(warp_rows):
+ # Assign A_shared_buf_elem
+ wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (
+ rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
+ A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
+
+ if ldmatrix_available:
+ T.ptx_ldmatrix(
+ a_dtype,
+ T.bool(trans),
+ 4,
+ ".b16",
+ A_local_buf.data,
+ i * local_size_a,
+ T.address_of(A_shared_buf_elem),
+ get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
+ )
+ else:
+ for j in T.serial(local_size_a):
+ mi, mk = mma_load_layout(tx, j)
+ A_local_buf[i * local_size_a +
+ j] = A_shared_buf[wk + mk, wi +
+ mi] if a_transposed else A_shared_buf[wi + mi,
+ wk + mk]
+
+ return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
+
+ def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0):
+ warp_row_tiles = self.warp_row_tiles
+ warp_rows = self.warp_rows
+ warp_k = self.warp_k
+ micro_size_x = self.micro_size_x
+ micro_size_k = self.micro_size_k
+ local_size_e = self.local_size_e
+ a_dtype = self.a_dtype
+ e_dtype = self.e_dtype
+ trans = self.e_transposed
+        # Metadata fragments are currently gathered with plain shared-memory loads rather than
+        # ldmatrix (cf. include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h).
+        ldmatrix_available = False # TODO: use ldmatrix when possible
+
+ def mma_load_layout(i, j):
+ return i, j
+
+ if not ldmatrix_available:
+ if DataType(e_dtype).bits == 8:
+ if DataType(a_dtype).bits == 8:
+ mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_8bit
+ elif DataType(a_dtype).bits == 16:
+ mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_16bit
+ elif DataType(a_dtype).bits == 32:
+ mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_32bit
+ else:
+ raise ValueError(f"Unsupported a_dtype for e_dtype 8bit: {a_dtype}")
+ elif DataType(e_dtype).bits == 16:
+ if DataType(a_dtype).bits == 8:
+ mma_load_layout = metadata_16bit_load_32x2_to_shared_16x4_layout_8bit
+ elif DataType(a_dtype).bits == 16:
+ mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_16bit
+ elif DataType(a_dtype).bits == 32:
+ mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_32bit
+ else:
+ raise ValueError(f"Unsupported a_dtype for e_dtype 16bit: {a_dtype}")
+ elif DataType(e_dtype).bits == 32:
+ if DataType(a_dtype).bits == 8:
+ mma_load_layout = metadata_32bit_load_32x1_to_shared_16x2_layout_8bit
+ else:
+ raise ValueError(f"Unsupported a_dtype for e_dtype 32bit: {a_dtype}")
+ else:
+ raise ValueError(f"Unsupported dtype: {e_dtype}")
+
+ thread_binding = self.get_thread_binding()
+
+ @T.macro
+ def _warp_ldmatrix_e(
+ E_local_buf,
+ E_shared_buf,
+ ki,
+ thread_binding,
+ rk=0,
+ ):
+ tx, _, warp_m = self.extract_thread_binding(thread_binding)
+ for i in T.serial(warp_rows):
+ # Assign E_shared_buf_elem
+ wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (
+ rk * warp_k + ki * micro_size_k) // self.e_factor
+ for j in T.serial(local_size_e):
+ mi, mk = mma_load_layout(tx, j)
+ E_local_buf[i * local_size_e +
+ j] = E_shared_buf[wk + mk,
+ wi + mi] if trans else E_shared_buf[wi + mi,
+ wk + mk]
+
+ return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk)
+
+ def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0):
+ warp_col_tiles = self.warp_col_tiles
+ warp_cols = self.warp_cols
+ warp_k = self.warp_k
+ micro_size_y = self.micro_size_y
+ micro_size_k = self.micro_size_k
+ local_size_b = self.local_size_b
+ b_dtype = self.b_dtype
+ b_transposed = self.b_transposed
+ thread_binding = self.get_thread_binding()
+ replicate_b = (self.n_dim == 16)
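+        # When n_dim == 16 a warp covers two m16n8k* tiles along N, so B is loaded twice per
+        # iteration (the second load offset by local_size_b // 2); see the paired ldmatrix calls below.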
+        # ldmatrix cannot be used when B is not transposed and its dtype is not 16-bit (e.g. int8/fp8/tf32).
+ ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
+
+ def mma_load_layout(i, j):
+ return i, j
+
+ if not ldmatrix_available:
+ if DataType(b_dtype).bits == 8:
+ mma_load_layout = mma_sp_load_b_32x32_to_shared_16x64_layout
+ elif DataType(b_dtype).bits == 16:
+ mma_load_layout = mma_sp_load_b_32x16_to_shared_16x32_layout
+ elif DataType(b_dtype).bits == 32:
+ mma_load_layout = mma_sp_load_b_32x8_to_shared_16x16_layout
+ else:
+ raise ValueError(f"Unsupported dtype: {b_dtype}")
+
+ @T.macro
+ def _warp_ldmatrix_b(
+ B_local_buf,
+ B_shared_buf,
+ ki,
+ thread_binding,
+ rk=0,
+ ):
+ stride = B_shared_buf.shape[-1]
+ tx, warp_n, _ = self.extract_thread_binding(thread_binding)
+ trans = not b_transposed
+
+ for i in T.serial(warp_cols):
+ # Assign B_shared_elem
+ wi, wk = (
+ warp_n * warp_col_tiles + i * micro_size_y,
+ rk * warp_k + ki * micro_size_k,
+ )
+
+ if ldmatrix_available:
+ B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk,
+ wi]
+
+ if replicate_b:
+ T.ptx_ldmatrix(
+ b_dtype,
+ T.bool(trans),
+ 4,
+ ".b16",
+ B_local_buf.data,
+ i * local_size_b,
+ T.address_of(B_shared_buf_elem),
+ get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed),
+ )
+
+ T.ptx_ldmatrix(
+ b_dtype,
+ T.bool(trans),
+ 4,
+ ".b16",
+ B_local_buf.data,
+ i * local_size_b + lift(local_size_b) // 2,
+ T.address_of(B_shared_buf_elem),
+ get_ldmatrix_offset_b("B", tx,
+ lift(local_size_b) // 2, stride, b_dtype,
+ b_transposed),
+ )
+ else:
+ T.ptx_ldmatrix(
+ b_dtype,
+ T.bool(trans),
+ 4,
+ ".b16",
+ B_local_buf.data,
+ i * local_size_b,
+ T.address_of(B_shared_buf_elem),
+ get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed),
+ )
+
+ else:
+                # Fall back to per-element loads from the shared buffer into the local fragment
+                # (tile shape depends on dtype: 16x16, 16x32, or 16x64 per warp).
+ for j in T.serial(local_size_b):
+ mi, mk = mma_load_layout(tx, j)
+ B_local_buf[i * local_size_b +
+ j] = B_shared_buf[wi + mi, wk +
+ mk] if b_transposed else B_shared_buf[wk + mk,
+ wi + mi]
+
+ return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
+
+ def mma_sp(self,
+ A_local_buf: Buffer,
+ E_local_buf: Buffer,
+ B_local_buf: Buffer,
+ C_local_buf: Buffer,
+ k_inner: PrimExpr = 0):
+ warp_rows = self.warp_rows
+ warp_cols = self.warp_cols
+ local_size_a = self.local_size_a
+ local_size_e = self.local_size_e
+ local_size_b = self.local_size_b
+ local_size_out = self.local_size_out
+ a_dtype_abbrv = self.a_dtype_abbrv
+ b_dtype_abbrv = self.b_dtype_abbrv
+ accum_dtype = self.accum_dtype
+ accum_dtype_abbrv = self.accum_dtype_abbrv
+ mma_prefix = self.mma_prefix
+ replicate_b = (self.n_dim == 16)
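+        # With n_dim == 16 each (i, j) tile issues two mma.sp instructions: the second consumes
+        # the upper half of the B fragment and accumulates into the upper half of the C fragment.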
+
+ a_is_fragment = is_fragment(A_local_buf)
+ e_is_fragment = is_fragment(E_local_buf)
+ b_is_fragment = is_fragment(B_local_buf)
+ assert not e_is_fragment, f"currently E_local_buf must be a local allocation, found {E_local_buf.scope()}"
+ a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
+ e_local_stride: PrimExpr = k_inner * warp_rows * local_size_e if e_is_fragment else 0
+ b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
+
+ @T.macro
+ def _warp_mma_sp(A_local_buf, E_local_buf, B_local_buf, C_local_buf):
+ for i, j in T.grid(warp_rows, warp_cols):
+ T.ptx_mma_sp(
+ accum_dtype,
+ mma_prefix,
+ "row",
+ "col",
+ a_dtype_abbrv,
+ b_dtype_abbrv,
+ accum_dtype_abbrv,
+ A_local_buf.data,
+ a_local_stride + i * local_size_a,
+ B_local_buf.data,
+ b_local_stride + j * local_size_b,
+ C_local_buf.data,
+ i * warp_cols * local_size_out + j * local_size_out,
+ E_local_buf.data, # metadata
+ e_local_stride + i * local_size_e, # metadata offset
+ self.SPARSE_SELECTOR, # sparse_selector
+ T.bool(False), # saturate
+ )
+ if replicate_b:
+ T.ptx_mma_sp(
+ accum_dtype,
+ mma_prefix,
+ "row",
+ "col",
+ a_dtype_abbrv,
+ b_dtype_abbrv,
+ accum_dtype_abbrv,
+ A_local_buf.data,
+ a_local_stride + i * local_size_a,
+ B_local_buf.data,
+ b_local_stride + j * local_size_b + lift(local_size_b) // 2,
+ C_local_buf.data,
+ i * warp_cols * local_size_out + j * local_size_out +
+ lift(local_size_out) // 2,
+ E_local_buf.data, # metadata
+ e_local_stride + i * local_size_e, # metadata offset
+ self.SPARSE_SELECTOR, # sparse_selector
+ T.bool(False), # saturate
+ )
+
+ return _warp_mma_sp(A_local_buf, E_local_buf, B_local_buf, C_local_buf)
+
+ def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
+ block_row_warps = self.block_row_warps
+ block_col_warps = self.block_col_warps
+ warp_rows = self.warp_rows
+ warp_cols = self.warp_cols
+ local_size_out = self.local_size_out
+
+ is_global = pid_m is not None and pid_n is not None
+ BLOCK_M = block_row_warps * warp_rows
+ BLOCK_N = block_col_warps * warp_cols
+ M_DIM, n_dim = self.M_DIM, self.n_dim
+ C_buf_dims = len(C_buf.shape)
+ assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
+
+ thread_binding = self.get_thread_binding()
+
+ # STS
+        # The MMA store must be emulated with manual stores instead of the TVM intrinsic,
+        # because the intrinsic assumes threadIdx.x is always equal to the warp size.
+ @T.macro
+ def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
+ tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+ for i, j in T.grid(warp_rows, warp_cols):
+ for local_id_o in T.serial(local_size_out // 2):
+ for local_id_i in T.vectorized(2):
+ local_id = local_id_o * 2 + local_id_i
+ row, col = T.meta_var(mma_store_index_map(tx, local_id))
+ if C_buf_dims == 2:
+ C_buf[(warp_m * warp_rows + i) * M_DIM + row,
+ (warp_n * warp_cols + j) * n_dim +
+ col] = C_local_buf[i * (warp_cols * local_size_out) +
+ j * local_size_out + local_id]
+ else:
+ C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
+ col] = C_local_buf[i * (warp_cols * local_size_out) +
+ j * local_size_out + local_id]
+
+ @T.macro
+ def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
+ tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+ for i, j in T.grid(warp_rows, warp_cols):
+ for local_id_o in T.serial(local_size_out // 2):
+ for local_id_i in T.vectorized(2):
+ local_id = local_id_o * 2 + local_id_i
+ row, col = T.meta_var(mma_store_index_map(tx, local_id))
+ C_buf[
+ (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
+ (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
+ ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
+ local_id]
+
+ return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
+ if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding))
+
+ def make_mma_load_layout(self,
+ local_buf: Buffer,
+ matrix: Literal["A", "B"] = "A") -> T.Fragment:
+ """
+        Create a layout function for loading an MMA operand (A or B) from a fragment buffer.
+        This layout is used in conjunction with `inverse_mma_load_layout` to
+        map fragment indices to threads and local indices.
+
+ Parameters
+ ----------
+ local_buf : tir.Buffer
+ The local buffer representing a fragment of a matrix.
+
+ Returns
+ -------
+ T.Fragment
+ A fragment object that describes how threads and indices
+ in `local_buf` are laid out.
+
+ Raises
+ ------
+ AssertionError
+ If `local_buf` is not detected to be a fragment buffer.
+ """
+ from tilelang.utils import is_fragment
+ assert matrix in ["A", "B"], "matrix should be either A or B"
+ matrix_is_a: bool = matrix == "A"
+ matrix_is_b: bool = matrix == "B"
+ dtype = self.a_dtype if matrix_is_a else self.b_dtype
+ dtype_bits = DataType(dtype).bits
+ transposed = self.a_transposed if matrix_is_a else self.b_transposed
+
+ # s represents spatial axis
+ # r represents reduction axis
+ # sr represents the two dims are spatial + reduction
+ # rs represents the two dims are reduction + spatial
+ # sr also can represent a non-transposed basic layout
+ # then rs also can represent a transposed basic layout
+ transform_func_sr_a: Callable = None
+ transform_func_sr_b: Callable = None
+ if dtype_bits == 32:
+ transform_func_sr_a = shared_16x16_to_mma_sp_layout_sr_a
+ transform_func_sr_b = shared_16x16_to_mma_sp_layout_sr_b
+ elif dtype_bits == 16:
+ transform_func_sr_a = shared_16x32_to_mma_sp_layout_sr_a
+ transform_func_sr_b = shared_16x32_to_mma_sp_layout_sr_b
+ elif dtype_bits == 8:
+ transform_func_sr_a = shared_16x64_to_mma_sp_layout_sr_a
+ transform_func_sr_b = shared_16x64_to_mma_sp_layout_sr_b
+ else:
+ raise ValueError(f"Unsupported dtype {dtype}")
+
+ is_sr_conditions = [False]
+ is_sr_conditions.append(matrix_is_a and not transposed)
+ is_sr_conditions.append(matrix_is_b and transposed)
+ is_sr_axis_order = any(is_sr_conditions)
+
+ # the layout of mma.sync is row.col.
+ # so the b matrix expected a transposed basic layout
+ transform_func: Callable = None
+ if matrix_is_a:
+ transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
+ j, i)
+ elif matrix_is_b:
+ transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
+ j, i)
+ else:
+ raise ValueError(f"Unsupported matrix {matrix}")
+
+ assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
+
+ if matrix_is_a:
+ micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
+ else:
+ micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
+
+ block_row_warps, block_col_warps = (
+ self.block_row_warps,
+ self.block_col_warps,
+ )
+
+ inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
+
+ def forward_thread(i: int, j: int) -> int:
+ """
+            Given the row index `i` and column index `j` in the fragment,
+            map them to a thread (lane) index according to `inverse_mma_load_layout`.
+ """
+ lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
+ return lane_id
+
+ def forward_index(i: int, j: int) -> int:
+ """
+            Given the row index `i` and column index `j` in the fragment,
+            map them to a local register index according to `inverse_mma_load_layout`.
+ """
+ _, local_id = inverse_mma_load_layout.map_indices([i, j])
+ return local_id
+
+ base_fragment = T.Fragment(
+ [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] if is_sr_axis_order
+ else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s],
+ forward_thread_fn=forward_thread,
+ forward_index_fn=forward_index,
+ )
+
+ warp_rows, warp_cols = self.warp_rows, self.warp_cols
+ chunk = self.warp_k
+
+ warp_s = warp_rows if matrix_is_a else warp_cols
+ warp_r = chunk // micro_size_r
+ block_s = block_row_warps if matrix_is_a else block_col_warps
+ replicate = block_col_warps if matrix_is_a else block_row_warps
+
+ if is_sr_axis_order:
+ warp_fragment = base_fragment.repeat([warp_s, warp_r],
+ repeat_on_thread=False,
+ lower_dim_first=False)
+ if matrix_is_a:
+ block_fragment = warp_fragment.repeat([block_s, 1],
+ repeat_on_thread=True,
+ lower_dim_first=True).replicate(replicate)
+ elif matrix_is_b:
+ block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
+ repeat_on_thread=True,
+ lower_dim_first=True)
+ else:
+ raise ValueError(f"Unsupported matrix type {matrix}")
+ else:
+ warp_fragment = base_fragment.repeat([warp_r, warp_s],
+ repeat_on_thread=False,
+ lower_dim_first=True)
+ if matrix_is_a:
+ block_fragment = warp_fragment.repeat([1, block_s],
+ repeat_on_thread=True,
+ lower_dim_first=True).replicate(replicate)
+ elif matrix_is_b:
+ block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
+ repeat_on_thread=True,
+ lower_dim_first=True)
+ else:
+ raise ValueError(f"Unsupported matrix type {matrix}")
+
+ return block_fragment
+
+ def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
+ """
+ Create a layout function for storing MMA results into a fragment buffer.
+ This layout is used in conjunction with `inverse_mma_store_layout` to
+ map fragment indices to threads and local indices.
+
+ Parameters
+ ----------
+ local_buf : tir.Buffer
+ The local buffer representing a fragment of a matrix.
+
+ Returns
+ -------
+ T.Fragment
+ A fragment object that describes how threads and indices
+ in `local_buf` are laid out.
+
+ Raises
+ ------
+ AssertionError
+ If `local_buf` is not detected to be a fragment buffer.
+ """
+ from tilelang.utils import is_fragment
+
+ shape = local_buf.shape
+ inverse_mma_store_layout = self.get_store_index_map(inverse=True)
+ assert is_fragment(local_buf), "local_buf must be a fragment"
+ micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
+ local_size_out = self.local_size_out
+ block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
+ warp_rows, warp_cols = self.warp_rows, self.warp_cols
+ warp_size = self.WARP_SIZE
+ is_m_first = self.is_m_first
+
+ def forward_thread(i: int, j: int) -> int:
+ """
+ Given the row index `i` and column index `j` in the fragment,
+ map them to a thread index according to `inverse_mma_store_layout`.
+ """
+ # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
+            # the upper bounds of block_i and block_j are block_row_warps and block_col_warps
+ block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
+ # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
+ mma_i, mma_j = i % micro_size_x, j % micro_size_y
+ lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
+ if is_m_first:
+                thread_id = block_i * (block_col_warps * warp_size) + block_j * warp_size + lane_id
+ else:
+ thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
+ return thread_id
+
+ def forward_index(i: int, j: int) -> int:
+ """
+ Given the row index `i` and column index `j` in the fragment,
+ map them to a local index in a single thread according
+ to `inverse_mma_store_layout`.
+ """
+ # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
+ # the upper bounds of warp_i and warp_j are warp_rows and warp_cols
+ warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
+ # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
+ mma_i, mma_j = i % micro_size_x, j % micro_size_y
+ _, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
+ return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
+
+ return T.Fragment(
+ shape,
+ forward_thread_fn=forward_thread,
+ forward_index_fn=forward_index,
+ )
diff --git a/tilelang/ir.py b/tilelang/ir.py
index cccf97e0a..08d4e96cd 100644
--- a/tilelang/ir.py
+++ b/tilelang/ir.py
@@ -39,6 +39,19 @@ def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target
return self.m_warp, self.n_warp
+@tvm_ffi.register_object("tl.GemmSPWarpPolicy")
+class GemmSPWarpPolicy(Node, Scriptable):
+ policy_type: int
+ m_warp: int
+ n_warp: int
+
+ def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
+ is_wgmma: bool, bits: int):
+ _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
+ is_wgmma, bits)
+ return self.m_warp, self.n_warp
+
+
@tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
...
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 75d8d0b4f..1f560a446 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -24,7 +24,15 @@
LocalBuffer, # noqa: F401
Ref, # noqa: F401
)
-from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
+from .loop import (
+ Parallel, # noqa: F401
+ Persistent, # noqa: F401
+ Pipelined, # noqa: F401
+ serial, # noqa: F401
+ unroll, # noqa: F401
+ Serial, # noqa: F401
+ Unroll, # noqa: F401
+)
from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401
from .kernel import (
@@ -51,7 +59,7 @@
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2 # noqa: F401
-from .experimental.gemm_sp import gemm_sp # noqa: F401
+from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401
from .fill import fill, clear # noqa: F401
from .reduce import (
reduce, # noqa: F401
diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py
index 56f87473f..07e45bbc8 100644
--- a/tilelang/language/atomic.py
+++ b/tilelang/language/atomic.py
@@ -212,9 +212,9 @@ def get_extent(data):
"return_prev is not supported for tile-region-based atomic operations")
if memory_order is None:
- return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, 0)
+ return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0)
else:
- return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma,
+ return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma,
_MEMORY_ORDER_ID_MAP[memory_order])
diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py
index 965919fd4..cabc4a3e4 100644
--- a/tilelang/language/copy.py
+++ b/tilelang/language/copy.py
@@ -90,7 +90,7 @@ def get_extent(data):
eviction_policy = 0
else:
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
- return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width,
+ return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width,
disable_tma, eviction_policy)
@@ -124,5 +124,5 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
img_region = to_buffer_region(img, access_type="r")
col_region = to_buffer_region(col, access_type="w")
- return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img_region, col_region,
+ return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region,
nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy)
diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py
index 7cc3d736d..4a20f3fb6 100644
--- a/tilelang/language/experimental/gemm_sp.py
+++ b/tilelang/language/experimental/gemm_sp.py
@@ -3,7 +3,15 @@
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
-from tilelang.utils.language import to_buffer_region
+from tilelang.utils.language import (
+ to_buffer_region,
+ retrieve_shape,
+ retrieve_stride,
+ retrieve_offset,
+ prim_expr_equal,
+)
+from tilelang.language.utils import (
+ buffer_region_to_tile_region,)
def gemm_sp(
@@ -70,7 +78,7 @@ def legalize_arguments(arg: tir.Buffer | tir.Var):
C_arg = to_buffer_region(C, access_type="rw")
return tir.call_intrin(
"handle",
- tir.op.Op.get("tl.gemm_sp"),
+ tir.op.Op.get("tl.tileop.gemm_sp"),
A_arg,
E_arg,
B_arg,
@@ -85,3 +93,128 @@ def legalize_arguments(arg: tir.Buffer | tir.Var):
k_pack,
wg_wait,
)
+
+
+# experimental currently, for fast compilation
+def gemm_sp_v2(
+ A_sparse: tir.Buffer | tir.Var,
+ E: tir.Buffer | tir.Var,
+ B: tir.Buffer | tir.Var,
+ C: tir.Buffer | tir.Var,
+ transpose_A: bool = False,
+ transpose_B: bool = False,
+ transpose_E: bool = False,
+ policy: GemmWarpPolicy = GemmWarpPolicy.Square,
+ clear_accum: bool = False,
+ k_pack: int = 1,
+ wg_wait: int = 0,
+):
+ """Perform a General Matrix Multiplication (GEMM) operation.
+
+    This function computes C = A @ B where A is given in structured-sparse form
+    (A_sparse holds the kept non-zero values and E holds the sparsity metadata),
+    and A and B can optionally be transposed.
+    The operation supports various warp policies and accumulation modes.
+
+ Args:
+ A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements
+ E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E
+ B (Union[tir.Buffer, tir.Var]): Second input matrix
+ C (Union[tir.Buffer, tir.Var]): Output matrix for results
+ transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
+        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_E (bool, optional): Whether to transpose the metadata matrix E. Defaults to False.
+ policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
+ clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
+ k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
+ wg_wait (int, optional): Warp group wait count. Defaults to 0.
+
+ Returns:
+ tir.Call: A handle to the GEMM operation
+
+ Raises:
+ AssertionError: If the K dimensions of matrices A and B don't match
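+
+    Example:
+        A minimal sketch (shapes are illustrative; e_factor depends on dtype and architecture):
+
+            T.gemm_sp_v2(A_sparse, E, B, C)
+
+        where A_sparse is (M, K // 2), E is (M, K // e_factor), B is (K, N), and C is (M, N).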
+ """
+
+ def legalize_arguments(arg: tir.Buffer | tir.Var):
+ """Convert let-bound variables to their corresponding buffers.
+
+ Args:
+ arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
+
+ Returns:
+ Union[tir.Buffer, tir.Var]: The legalized argument
+ """
+ if isinstance(arg, tir.Var) and T.has_let_value(arg):
+ return T.get_let_value(arg).buffer
+ return arg
+
+ A_sparse = legalize_arguments(A_sparse)
+ E = legalize_arguments(E)
+ B = legalize_arguments(B)
+ C = legalize_arguments(C)
+
+ A_region = to_buffer_region(A_sparse)
+ E_region = to_buffer_region(E)
+ B_region = to_buffer_region(B)
+ C_region = to_buffer_region(C)
+
+ A_shape = retrieve_shape(A_sparse)
+    E_shape = retrieve_shape(E)
+ B_shape = retrieve_shape(B)
+ C_shape = retrieve_shape(C)
+
+ A_stride = retrieve_stride(A_sparse)
+ B_stride = retrieve_stride(B)
+
+    assert len(C_shape) == 2, "currently only a 2D C tensor is supported"
+    assert len(A_shape) >= 2, "currently only a 2D or higher-order A tensor is supported"
+    assert len(B_shape) >= 2, "currently only a 2D or higher-order B tensor is supported"
+    if len(A_shape) > 2:
+        for i in range(len(A_shape) - 2):
+            assert A_shape[i] == 1, \
+                "currently A must be 2D, or a higher-order tensor whose leading dimensions are 1 and whose last two dimensions are the matrix dimensions"
+    if len(B_shape) > 2:
+        for i in range(len(B_shape) - 2):
+            assert B_shape[i] == 1, \
+                "currently B must be 2D, or a higher-order tensor whose leading dimensions are 1 and whose last two dimensions are the matrix dimensions"
+
+ M, N = C_shape
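+    # A_sparse keeps only half of the K elements, so the dense K is twice its K-dim extent.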
+ K = 2 * (A_shape[-2] if transpose_A else A_shape[-1])
+ K_B = B_shape[-1] if transpose_B else B_shape[-2]
+    assert prim_expr_equal(
+        K, K_B), f"T.gemm_sp_v2 K shape check failed: dense K from A = {K}, K from B = {K_B}"
+
+ stride_a = A_stride[-2]
+ stride_b = B_stride[-2]
+
+ A_offset = retrieve_offset(A_sparse)
+ B_offset = retrieve_offset(B)
+ assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
+ assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
+ offset_a = A_offset[-1]
+ offset_b = B_offset[-1]
+
+ A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
+ E_arg = buffer_region_to_tile_region(E_region, "r", [r for r in E_shape])
+ B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
+ C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
+ return tir.call_intrin(
+ "handle",
+ tir.op.Op.get("tl.tileop.gemm_sp_py"),
+ A_arg,
+ E_arg,
+ B_arg,
+ C_arg,
+ transpose_A,
+ transpose_B,
+ transpose_E,
+ M,
+ N,
+ K,
+ policy,
+ clear_accum,
+ stride_a,
+ stride_b,
+ offset_a,
+ offset_b,
+ k_pack,
+ wg_wait,
+ )
diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py
index fbbcf1b63..b23733377 100644
--- a/tilelang/language/fill.py
+++ b/tilelang/language/fill.py
@@ -32,7 +32,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
extents = [tir.IntImm("int32", 1) for _ in buffer.indices]
else:
extents = []
- return tir.call_intrin("handle", tir.op.Op.get("tl.fill"),
+ return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"),
to_buffer_region(buffer, access_type="w", extents=extents), value)
diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py
index 2bfd3a0cf..db8e04aba 100644
--- a/tilelang/language/gemm.py
+++ b/tilelang/language/gemm.py
@@ -116,7 +116,7 @@ def gemm_v1(
):
"""GEMM v1: use op tl.gemm."""
return _gemm_impl(
- "tl.gemm",
+ "tl.tileop.gemm",
A,
B,
C,
@@ -145,7 +145,7 @@ def gemm_v2(
):
"""GEMM v2: use op tl.gemm_py."""
return _gemm_impl(
- "tl.gemm_py",
+ "tl.tileop.gemm_py",
A,
B,
C,
diff --git a/tilelang/language/loop.py b/tilelang/language/loop.py
index 4f8d5c307..3478b6cc1 100644
--- a/tilelang/language/loop.py
+++ b/tilelang/language/loop.py
@@ -4,8 +4,9 @@
from tvm import tir
from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir
-from .v2.builder import SerialForWithStep
+from .v2.builder import SerialForWithStep, UnrollForWithStep
from tilelang import _ffi_api
+from tvm.script.ir_builder.tir import frame
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
@@ -97,7 +98,7 @@ def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
- annotations: dict[str, Any] | None = None):
+ annotations: dict[str, Any] | None = None) -> frame.ForFrame:
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
@@ -108,3 +109,70 @@ def serial(start: tir.PrimExpr,
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
+
+
+def unroll(start: tir.PrimExpr,
+ stop: tir.PrimExpr | None = None,
+ step: tir.PrimExpr | None = None,
+ *,
+ explicit: bool = False,
+ unroll_factor: int | None = None,
+ annotations: dict[str, Any] | None = None) -> frame.ForFrame:
+ """The unrolled For statement.
+
+ Parameters
+ ----------
+ start : PrimExpr
+ The minimum value of iteration.
+
+ stop : PrimExpr
+ The maximum value of iteration.
+
+ step : PrimExpr
+ The step size of the iteration.
+
+ explicit : bool
+ Whether to explicitly unroll the loop.
+
+ unroll_factor : int
+ The unroll factor of the loop.
+
+ annotations : Dict[str, Any]
+ The optional annotations of the For statement.
+
+ Returns
+ -------
+ res : frame.ForFrame
+ The ForFrame.
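+
+    Examples
+    --------
+    A minimal sketch (bounds and factors are illustrative)::
+
+        for i in T.unroll(8):
+            ...  # unroll hint on the generated loop
+        for i in T.unroll(0, 64, unroll_factor=4):
+            ...  # request a partial unroll by a factor of 4 (requires explicit=False)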
+ """
+
+    step_is_one = False
+    step_is_one |= isinstance(step, int) and step == 1
+    step_is_one |= isinstance(step, IntImm) and step.value == 1
+ if stop is None:
+ stop = start
+ if hasattr(start, "dtype"):
+ start = IntImm(start.dtype, 0)
+ else:
+ start = 0
+
+    # Ensure annotations carries "pragma_unroll_explicit" (defaulting to the `explicit` argument)
+ if annotations is None:
+ annotations = {"pragma_unroll_explicit": explicit}
+ else:
+        # Add "pragma_unroll_explicit": explicit if not already present
+ annotations = dict(annotations)
+ annotations.setdefault("pragma_unroll_explicit", explicit)
+
+ if unroll_factor is not None:
+ # check pragma_unroll_explicit must be False
+ if annotations.get("pragma_unroll_explicit", True):
+            raise ValueError("pragma_unroll_explicit must be False when unroll_factor is not None")
+ annotations.update({"pragma_unroll_factor": unroll_factor})
+
+ if step is None or step_is_one:
+ return tb_tir.unroll(start, stop, annotations=annotations)
+ else:
+ return UnrollForWithStep(start, stop, step, annotations=annotations)
+
+
+Serial = serial
+Unroll = unroll
diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py
index 3c4d8187b..fb84b6d78 100644
--- a/tilelang/language/reduce.py
+++ b/tilelang/language/reduce.py
@@ -13,6 +13,9 @@ def _legalize_dim(buffer: tir.Buffer, dim: int):
return dim
+_REDUCE_OP_KEY = "tl.tileop.reduce"
+
+
def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
"""Perform a reduction operation on a buffer along a specified dimension.
@@ -50,7 +53,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int
copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
- tir.op.Op.get("tl.reduce"),
+ tir.op.Op.get(_REDUCE_OP_KEY),
to_buffer_region(red_frag_in, access_type="r"),
to_buffer_region(red_frag_out, access_type="w"),
reduce_type,
@@ -65,7 +68,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int
copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
- tir.op.Op.get("tl.reduce"),
+ tir.op.Op.get(_REDUCE_OP_KEY),
to_buffer_region(red_frag_in, access_type="r"),
to_buffer_region(out, access_type="w"),
reduce_type,
@@ -78,7 +81,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int
tir.call_intrin(
"handle",
- tir.op.Op.get("tl.reduce"),
+ tir.op.Op.get(_REDUCE_OP_KEY),
to_buffer_region(buffer, access_type="r"),
to_buffer_region(red_frag_out, access_type="w"),
reduce_type,
@@ -89,7 +92,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int
elif is_fragment(buffer) and is_fragment(out):
tir.call_intrin(
"handle",
- tir.op.Op.get("tl.reduce"),
+ tir.op.Op.get(_REDUCE_OP_KEY),
to_buffer_region(buffer, access_type="r"),
to_buffer_region(out, access_type="w"),
reduce_type,
@@ -245,7 +248,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
copy(src, cumsum_smem)
tir.call_intrin(
"handle",
- tir.op.Op.get("tl.cumsum"),
+ tir.op.Op.get("tl.tileop.cumsum"),
to_buffer_region(cumsum_smem, access_type="r"),
to_buffer_region(cumsum_smem, access_type="w"),
dim,
@@ -299,7 +302,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse
return cumsum_fragment(src, dst, dim, reverse)
return tir.call_intrin(
"handle",
- tir.op.Op.get("tl.cumsum"),
+ tir.op.Op.get("tl.tileop.cumsum"),
to_buffer_region(src, access_type="r"),
to_buffer_region(dst, access_type="w"),
dim,
@@ -309,7 +312,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse
def finalize_reducer(reducer: tir.Buffer):
"""
- Finalize a reducer buffer by emitting the `tl.finalize_reducer` intrinsic.
+ Finalize a reducer buffer by emitting the `tl.tileop.finalize_reducer` intrinsic.
This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer.
The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR.
@@ -322,7 +325,7 @@ def finalize_reducer(reducer: tir.Buffer):
"""
return tir.call_intrin(
"handle",
- tir.op.Op.get("tl.finalize_reducer"),
+ tir.op.Op.get("tl.tileop.finalize_reducer"),
to_buffer_region(reducer, access_type="w"),
)
diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py
index 75fea4c09..136bc0bac 100644
--- a/tilelang/language/utils.py
+++ b/tilelang/language/utils.py
@@ -7,7 +7,7 @@
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
"""Create a tl.region call for a BufferLoad and extents."""
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
- return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args)
+ return T.call_intrin("handle", op.Op.get("tl.tileop.region"), buffer, access_type, *args)
def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]):
diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py
index aea425adc..68c109133 100644
--- a/tilelang/language/v2/builder.py
+++ b/tilelang/language/v2/builder.py
@@ -112,6 +112,11 @@ class SerialForWithStep:
annotations: dict[str, Any] | None = None
+@dataclass
+class UnrollForWithStep(SerialForWithStep):
+ ...
+
+
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame)
@@ -270,7 +275,7 @@ def eval(self, val: Any):
def ctx_for(self, it):
self.check_continue_break()
it = unwrap_expr(it)
- if isinstance(it, SerialForWithStep):
+ if isinstance(it, (SerialForWithStep, UnrollForWithStep)):
# Validate and compute the trip count before constructing the frame
if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value
@@ -285,7 +290,14 @@ def ctx_for(self, it):
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
)
real_stop = tir.ceildiv(it.stop - it.start, it.step)
- real_frame = tir.serial(real_stop, annotations=it.annotations)
+ if isinstance(it, UnrollForWithStep):
+ real_frame = tir.unroll(real_stop, annotations=it.annotations)
+ elif isinstance(it, SerialForWithStep):
+ real_frame = tir.serial(real_stop, annotations=it.annotations)
+ else:
+ raise TypeError(
+ f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
+ "range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding")
with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v)
yield it.start + v * it.step
diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py
index ee513257f..777802d2c 100644
--- a/tilelang/layout/__init__.py
+++ b/tilelang/layout/__init__.py
@@ -13,4 +13,4 @@
make_quarter_bank_swizzled_layout, # noqa: F401
make_linear_layout, # noqa: F401
)
-from .gemm_sp import make_metadata_layout # noqa: F401
+from .gemm_sp import make_cutlass_metadata_layout # noqa: F401
diff --git a/tilelang/layout/gemm_sp.py b/tilelang/layout/gemm_sp.py
index eaaa178f5..e5d190292 100644
--- a/tilelang/layout/gemm_sp.py
+++ b/tilelang/layout/gemm_sp.py
@@ -17,7 +17,7 @@ def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]:
return res
-def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int):
+def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int):
"""Make a layout of metadata that is compatible with cutlass sm90 compression kernel. Note that layout atom is the same for smem and gmem.
Args:
@@ -30,7 +30,7 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
block_k = 128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2)
- if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8"]:
+ if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8_e4m3", "float8_e5m2"]:
raise NotImplementedError(f"Unsupported dtype: {mma_dtype}")
if buffer.dtype not in ["uint8", "int8"]:
@@ -41,7 +41,8 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
"bfloat16": 16,
"float32": 32,
"int8": 8,
- "float8": 8,
+ "float8_e4m3": 8,
+ "float8_e5m2": 8,
}
# ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117
@@ -75,8 +76,8 @@ def gen_stride(shape_ik, order):
shape_i, shape_k = shape_ik[:3], shape_ik[3:]
stride_i, stride_k = stride_ik[:3], stride_ik[3:]
elif bits_map[mma_dtype] == 8:
- shape_i, shape_k = [64], [BlockK]
- stride_i, stride_k = [BlockK], [1]
+ shape_i, shape_k = [64], [block_k // 8]
+ stride_i, stride_k = [block_k // 8], [1]
else:
raise NotImplementedError(f"Unknown mma type {mma_dtype}")
@@ -103,54 +104,48 @@ def transform(i: int, k: int) -> int:
return T.Layout(shape, transform)
-def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str):
+def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
"""Make a layout of metadata that is compatible with cutlass sm8x compression kernel. Note that layout atom is the same for smem and gmem.
-
+ ref: https://github.com/pytorch/pytorch/blob/d0c24b392cbb7b213d22e42c52c6c2d1ac2da1bd/torch/sparse/_semi_structured_conversions.py#L5
Args:
buffer: metadata buffer shape, for sm80 it should be a 16bit type
"""
- # ref: https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h#L651
- # https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/matrix.h#L405
- # https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/warp/mma_sparse_tensor_op.h#L172
-
if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
- if mma_dtype in ["float8", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]:
+ if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"
+ ] and buffer.dtype not in ["uint32", "int32"]:
raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
- kInterleaved = 2
- stride = buffer.shape[0] * kInterleaved
+ m, k = buffer.shape
+ group = 32 if buffer.dtype.bits == 16 else 16
+ interweave = 4 if buffer.dtype.bits == 16 else 2
def ColumnMajorInterleaved(i: int, j: int) -> int:
- column_major = j // kInterleaved
- column_minor = j % kInterleaved
- return column_major * stride + i * kInterleaved + column_minor
+ i = i // group * group + (i % 8) * interweave + (i % group) // 8
+ topright = (1 - (i % 2)) & (j % 2)
+ bottomleft = (i % 2) & (1 - (j % 2))
+ i += topright - bottomleft
+ j -= topright - bottomleft
+ offset = (j // 2) * m * 2 + i * 2 + (j % 2)
+ return offset // k, offset % k
return T.Layout(buffer.shape, ColumnMajorInterleaved)
-def make_metadata_layout(buffer: tvm.tir.Buffer,
- mma_dtype: str = "float16",
- backend: str = 'cutlass',
- arch: str | None = None,
- **extra_args):
+def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer,
+ mma_dtype: str = "float16",
+ arch: str | None = None,
+ **extra_args):
if arch is None:
arch = nvcc.get_target_compute_version()
compute_version = nvcc.parse_compute_version(arch)
if compute_version >= (9, 0):
- if backend == 'cutlass':
- return _make_metadata_layout_sm90_cutlass(
- buffer=buffer, mma_dtype=mma_dtype, **extra_args)
- else:
- raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}")
+ return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args)
elif compute_version >= (8, 0):
- if backend == 'cutlass':
- return _make_metadata_layout_sm8x_cutlass(buffer=buffer, mma_dtype=mma_dtype)
- else:
- raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}")
+ return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype)
else:
raise NotImplementedError(f"Unsupported architecture: {arch}")
diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py
index 5af1fc2bf..4750fa7d5 100644
--- a/tilelang/profiler/__init__.py
+++ b/tilelang/profiler/__init__.py
@@ -10,6 +10,7 @@
get_tensor_supply,
TensorSupplyType,
torch_assert_close,
+ is_float8_dtype,
)
from tilelang.engine.param import KernelParam
from tilelang.jit.adapter import BaseKernelAdapter
@@ -125,17 +126,9 @@ def assert_allclose(
if lhs is not None and rhs is not None:
# in case of numsplit template, the ref output may be None
# which means the value is invalid, so we skip the comparison
- def is_float8(tensor: torch.Tensor) -> bool:
- return tensor.dtype in {
- torch.float8_e5m2,
- torch.float8_e5m2fnuz,
- torch.float8_e4m3fn,
- torch.float8_e4m3fnuz,
- }
-
torch_assert_close(
- lhs if not is_float8(lhs) else lhs.to(torch.float32),
- rhs if not is_float8(rhs) else rhs.to(torch.float32),
+ lhs if not is_float8_dtype(lhs.dtype) else lhs.to(torch.float32),
+ rhs if not is_float8_dtype(rhs.dtype) else rhs.to(torch.float32),
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
diff --git a/tilelang/tileop/__init__.py b/tilelang/tileop/__init__.py
index 5656494fe..a99cbd878 100644
--- a/tilelang/tileop/__init__.py
+++ b/tilelang/tileop/__init__.py
@@ -1 +1,2 @@
from .gemm import GemmPy # noqa: F401
+from .gemm_sp import GemmSPPy # noqa: F401
diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py
index 4c6762450..90960904f 100644
--- a/tilelang/tileop/gemm/__init__.py
+++ b/tilelang/tileop/gemm/__init__.py
@@ -3,6 +3,7 @@
from tvm import tir
from tvm.target import Target
from tvm.ir.base import Node
+from tvm.ir import Range
from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy
@@ -16,13 +17,14 @@
@tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
-def gemm_py_infer_layout(gemm_py, target, thread_bounds):
+def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range):
thread_nums = thread_bounds.extent
return gemm_py.infer_layout(target, thread_nums)
@tvm_ffi.register_global_func("tl.gemm_py.lower")
-def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
+def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range,
+ thread_var: tir.Var):
thread_nums = thread_bounds.extent
stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var)
return stmt
diff --git a/tilelang/tileop/gemm_sp/__init__.py b/tilelang/tileop/gemm_sp/__init__.py
new file mode 100644
index 000000000..fdac694ce
--- /dev/null
+++ b/tilelang/tileop/gemm_sp/__init__.py
@@ -0,0 +1,69 @@
+from tilelang import tvm as tvm
+from tvm import tir
+from tilelang.utils.target import (
+ target_is_cuda,)
+from tvm.target import Target
+from tvm.ir.base import Node
+from tvm.ir import Range
+from tvm.runtime import Scriptable
+import tvm_ffi
+from tilelang.ir import GemmWarpPolicy
+from .gemm_sp_mma import GemmSPMMA
+
+
+@tvm_ffi.register_global_func("tl.gemm_sp_py.infer_layout")
+def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range):
+ thread_nums = thread_bounds.extent
+ return gemm_sp_py.infer_layout(target, thread_nums)
+
+
+@tvm_ffi.register_global_func("tl.gemm_sp_py.lower")
+def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range,
+ thread_var: tir.Var):
+ thread_nums = thread_bounds.extent
+ stmt = gemm_sp_py.lower(target, thread_nums, thread_var)
+ return stmt
+
+
+@tvm_ffi.register_object("tl.GemmSPPy")
+class GemmSPPy(Node, Scriptable):
+ A: tir.Buffer
+ E: tir.Buffer
+ B: tir.Buffer
+ C: tir.Buffer
+
+ APtr: tir.PrimExpr
+ EPtr: tir.PrimExpr
+ BPtr: tir.PrimExpr
+ CPtr: tir.PrimExpr
+
+ M: int
+ N: int
+ K: int
+
+ trans_A: bool
+ trans_B: bool
+
+ stride_A: int
+ stride_B: int
+ offset_A: int
+ offset_B: int
+ clear_accum: bool
+ k_pack: int
+ wg_wait: int
+ policy: GemmWarpPolicy
+
+ def infer_layout(self, target: Target, thread_nums: int):
+ if target_is_cuda(target):
+ # TODO(lei): Support more cuda architectures, now mma only
+ return GemmSPMMA(self).infer_layout(target, thread_nums)
+ else:
+ raise ValueError(f"Unsupported target: {target}")
+
+ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
+ if target_is_cuda(target):
+ # TODO(lei): Support more cuda architectures, now mma only
+ # Now only implement ssr layout
+ return GemmSPMMA(self).lower(target, thread_nums, thread_var)
+ else:
+ raise ValueError(f"Unsupported target: {target}")
diff --git a/tilelang/tileop/gemm_sp/gemm_sp_base.py b/tilelang/tileop/gemm_sp/gemm_sp_base.py
new file mode 100644
index 000000000..51c6786b4
--- /dev/null
+++ b/tilelang/tileop/gemm_sp/gemm_sp_base.py
@@ -0,0 +1,131 @@
+from dataclasses import dataclass
+from tilelang import tvm as tvm
+from tvm.target import Target
+from tvm import tir
+from tilelang.utils.language import is_shared, is_fragment
+from tilelang.ir import GemmWarpPolicy
+from tvm.ir.base import Node
+
+
+@dataclass
+class GemmSPBase:
+ gemm_sp_node: Node
+
+ def infer_layout(self, target: Target, thread_nums: int):
+ raise NotImplementedError("infer_layout is not implemented")
+
+ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
+ raise NotImplementedError("lower is not implemented")
+
+ def is_gemm_ss(self) -> bool:
+ return is_shared(self.A) and is_shared(self.B)
+
+ def is_gemm_sr(self) -> bool:
+ return is_shared(self.A) and is_fragment(self.B)
+
+ def is_gemm_rs(self) -> bool:
+ return is_fragment(self.A) and is_shared(self.B)
+
+ def is_gemm_rr(self) -> bool:
+ return is_fragment(self.A) and is_fragment(self.B)
+
+ @property
+ def M(self) -> int:
+ return self.gemm_sp_node.M
+
+ @property
+ def N(self) -> int:
+ return self.gemm_sp_node.N
+
+ @property
+ def K(self) -> int:
+ return self.gemm_sp_node.K
+
+ @property
+ def trans_A(self) -> bool:
+ return self.gemm_sp_node.trans_A
+
+ @property
+ def trans_B(self) -> bool:
+ return self.gemm_sp_node.trans_B
+
+ @property
+ def trans_E(self) -> bool:
+ return self.gemm_sp_node.trans_E
+
+ @property
+ def e_dtype(self) -> str:
+ return self.E.dtype
+
+ @property
+ def in_dtype(self) -> str:
+ assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
+ return self.A.dtype
+
+ @property
+ def accum_dtype(self) -> str:
+ return self.C.dtype
+
+ @property
+ def A(self) -> tir.Buffer:
+ return self.gemm_sp_node.A
+
+ @property
+ def E(self) -> tir.Buffer:
+ return self.gemm_sp_node.E
+
+ @property
+ def B(self) -> tir.Buffer:
+ return self.gemm_sp_node.B
+
+ @property
+ def C(self) -> tir.Buffer:
+ return self.gemm_sp_node.C
+
+ @property
+ def ARegion(self) -> tir.PrimExpr:
+ return self.gemm_sp_node.ARegion
+
+ @property
+ def ERegion(self) -> tir.PrimExpr:
+ return self.gemm_sp_node.ERegion
+
+ @property
+ def BRegion(self) -> tir.PrimExpr:
+ return self.gemm_sp_node.BRegion
+
+ @property
+ def CRegion(self) -> tir.PrimExpr:
+ return self.gemm_sp_node.CRegion
+
+ @property
+ def stride_A(self) -> int:
+ return self.gemm_sp_node.stride_A
+
+ @property
+ def stride_B(self) -> int:
+ return self.gemm_sp_node.stride_B
+
+ @property
+ def offset_A(self) -> int:
+ return self.gemm_sp_node.offset_A
+
+ @property
+ def offset_B(self) -> int:
+ return self.gemm_sp_node.offset_B
+
+ @property
+ def clear_accum(self) -> bool:
+ return self.gemm_sp_node.clear_accum
+
+ @property
+ def k_pack(self) -> int:
+ return self.gemm_sp_node.k_pack
+
+ @property
+ def wg_wait(self) -> int:
+ return self.gemm_sp_node.wg_wait
+
+ @property
+ def policy(self) -> GemmWarpPolicy:
+ return self.gemm_sp_node.policy
diff --git a/tilelang/tileop/gemm_sp/gemm_sp_mma.py b/tilelang/tileop/gemm_sp/gemm_sp_mma.py
new file mode 100644
index 000000000..50a40bb91
--- /dev/null
+++ b/tilelang/tileop/gemm_sp/gemm_sp_mma.py
@@ -0,0 +1,247 @@
+from .gemm_sp_base import GemmSPBase
+from tilelang.layout import make_swizzled_layout
+from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
+from tilelang.utils.language import is_shared, is_fragment
+from tilelang import tvm as tvm
+from tvm.target import Target
+from tvm import tir
+from tilelang import language as T
+from tilelang.transform.simplify import _Simplify
+
+
+class GemmSPMMA(GemmSPBase):
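+    """Sparse GEMM backend for CUDA sparse Tensor Cores (``mma.sp``).
+
+    Partitions the tile across warps according to the warp policy and uses
+    ``SparseTensorCoreIntrinEmitter`` to emit ``ldmatrix``/``mma.sp``
+    sequences for the supported shared/fragment operand combinations.
+    """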
+
+ def infer_layout(self, target: Target, thread_nums: int):
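+        """Return the buffer layouts required by the sparse MMA path.
+
+        Shared-memory operands get swizzled layouts, register operands get
+        MMA load layouts, and the accumulator C gets the MMA store layout.
+        """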
+ m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
+ False)
+ warp_row_tiles = int(self.M // m_warp)
+ warp_col_tiles = int(self.N // n_warp)
+ mma_emitter = SparseTensorCoreIntrinEmitter(
+ a_dtype=self.in_dtype,
+ e_dtype=self.e_dtype,
+ b_dtype=self.in_dtype,
+ accum_dtype=self.accum_dtype,
+ a_transposed=self.trans_A,
+ b_transposed=self.trans_B,
+ e_transposed=self.trans_E,
+ block_row_warps=m_warp,
+ block_col_warps=n_warp,
+ warp_row_tiles=warp_row_tiles,
+ warp_col_tiles=warp_col_tiles,
+ warp_k=self.K,
+ )
+ if self.is_gemm_ss():
+ return {
+ self.A: make_swizzled_layout(self.A),
+ self.B: make_swizzled_layout(self.B),
+ self.C: mma_emitter.make_mma_store_layout(self.C),
+ }
+ elif self.is_gemm_sr():
+ return {
+ self.A: make_swizzled_layout(self.A),
+ self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
+ self.C: mma_emitter.make_mma_store_layout(self.C),
+ }
+ elif self.is_gemm_rs():
+ return {
+ self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
+ self.B: make_swizzled_layout(self.B),
+ self.C: mma_emitter.make_mma_store_layout(self.C),
+ }
+ elif self.is_gemm_rr():
+ return {
+ self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
+ self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
+ self.C: mma_emitter.make_mma_store_layout(self.C),
+ }
+ else:
+ raise ValueError(
+ f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
+
+ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
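+        """Emit the TIR body of the sparse GEMM.
+
+        Builds a ``SparseTensorCoreIntrinEmitter`` for the chosen warp
+        partition and returns a simplified macro that loads the A/E/B
+        fragments still residing in shared memory and issues ``mma.sp``
+        instructions, accumulating into C.
+        """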
+ m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
+ False)
+ warp_row_tiles = int(self.M // m_warp)
+ warp_col_tiles = int(self.N // n_warp)
+ mma_emitter = SparseTensorCoreIntrinEmitter(
+ a_dtype=self.in_dtype,
+ b_dtype=self.in_dtype,
+ e_dtype=self.e_dtype,
+ accum_dtype=self.accum_dtype,
+ a_transposed=self.trans_A,
+ b_transposed=self.trans_B,
+ e_transposed=self.trans_E,
+ block_row_warps=m_warp,
+ block_col_warps=n_warp,
+ warp_row_tiles=warp_row_tiles,
+ warp_col_tiles=warp_col_tiles,
+ warp_k=self.K,
+ thread_var=thread_var,
+ )
+
+ in_dtype = self.in_dtype
+ warp_rows = mma_emitter.warp_rows
+ warp_cols = mma_emitter.warp_cols
+ local_size_a = mma_emitter.local_size_a
+ local_size_e = mma_emitter.local_size_e
+ local_size_b = mma_emitter.local_size_b
+ micro_size_k = mma_emitter.micro_size_k
+ A_shared = self.A
+ E_shared = self.E
+ B_shared = self.B
+ C_local = self.C
+ assert micro_size_k <= self.K, f"K dimension {self.K} should be >= micro size k {micro_size_k}"
+ if self.is_gemm_ss():
+
+ @T.prim_func
+ def _gemm_ssr() -> None:
+ """
+            The inner macro that loads A, the sparsity metadata E, and B from
+            shared memory into local fragments, then issues sparse Tensor Core
+            mma.sp ops, accumulating into C_local.
+ """
+ A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+ E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
+ B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+
+ for ki in T.serial(0, (self.K // micro_size_k)):
+ # Load A into fragment
+ mma_emitter.ldmatrix_a(
+ A_local,
+ A_shared,
+ ki,
+ )
+
+ # Load E into fragment
+ mma_emitter.ldmatrix_e(
+ E_local,
+ E_shared,
+ ki,
+ )
+
+ # Load B into fragment
+ mma_emitter.ldmatrix_b(
+ B_local,
+ B_shared,
+ ki,
+ )
+
+ # Perform Matrix Multiplication
+ mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
+
+            # Simplify to optimize index computation.
+            # Let statements must be inlined to simplify the analysis.
+ return _Simplify(_gemm_ssr, inline_let=True)
+ elif self.is_gemm_sr():
+ B_local = self.B
+
+ @T.prim_func
+ def _gemm_srr() -> None:
+ """
+            The inner macro that loads A and the sparsity metadata E from
+            shared memory into local fragments (B is already a register
+            fragment), then issues sparse Tensor Core mma.sp ops,
+            accumulating into C_local.
+ """
+ A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+ E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
+
+ for ki in T.serial(0, (self.K // micro_size_k)):
+
+ # Load A into fragment
+ mma_emitter.ldmatrix_a(
+ A_local,
+ A_shared,
+ ki,
+ )
+
+ # Load E into fragment
+ mma_emitter.ldmatrix_e(
+ E_local,
+ E_shared,
+ ki,
+ )
+
+ # Perform Matrix Multiplication
+ mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
+
+            # Simplify to optimize index computation.
+            # Let statements must be inlined to simplify the analysis.
+ return _Simplify(_gemm_srr, inline_let=True)
+ elif self.is_gemm_rs():
+ A_local = self.A
+
+ @T.prim_func
+ def _gemm_rsr() -> None:
+ """
+            The inner macro that loads the sparsity metadata E and B from
+            shared memory into local fragments (A is already a register
+            fragment), then issues sparse Tensor Core mma.sp ops,
+            accumulating into C_local.
+ """
+ E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
+ B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+
+ for ki in T.serial(0, (self.K // micro_size_k)):
+ # Load E into fragment
+ mma_emitter.ldmatrix_e(
+ E_local,
+ E_shared,
+ ki,
+ )
+
+ # Load B into fragment
+ mma_emitter.ldmatrix_b(
+ B_local,
+ B_shared,
+ ki,
+ )
+
+ # Perform Matrix Multiplication
+ mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
+
+            # Simplify to optimize index computation.
+            # Let statements must be inlined to simplify the analysis.
+ return _Simplify(_gemm_rsr, inline_let=True)
+ elif self.is_gemm_rr():
+ A_local = self.A
+ B_local = self.B
+
+ @T.prim_func
+ def _gemm_rrr() -> None:
+ """
+            The inner macro that loads only the sparsity metadata E from
+            shared memory into a local fragment (A and B are already register
+            fragments), then issues sparse Tensor Core mma.sp ops,
+            accumulating into C_local.
+ """
+ E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
+
+ for ki in T.serial(0, (self.K // micro_size_k)):
+ # Load E into fragment
+ mma_emitter.ldmatrix_e(
+ E_local,
+ E_shared,
+ ki,
+ )
+
+ # Perform Matrix Multiplication
+ mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
+
+            # Simplify to optimize index computation.
+            # Let statements must be inlined to simplify the analysis.
+ return _Simplify(_gemm_rrr, inline_let=True)
+ else:
+ raise ValueError(
+ f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
+
+ def is_gemm_ss(self) -> bool:
+ return is_shared(self.A) and is_shared(self.B)
+
+ def is_gemm_sr(self) -> bool:
+ return is_shared(self.A) and is_fragment(self.B)
+
+ def is_gemm_rs(self) -> bool:
+ return is_fragment(self.A) and is_shared(self.B)
+
+ def is_gemm_rr(self) -> bool:
+ return is_fragment(self.A) and is_fragment(self.B)
diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py
index cd364b8bb..a7b17ad93 100644
--- a/tilelang/utils/sparse.py
+++ b/tilelang/utils/sparse.py
@@ -3,6 +3,7 @@
import torch
import warnings
from tilelang.contrib import nvcc
+from tilelang.utils.tensor import is_float8_dtype, fp8_remove_negative_zeros_
from torch.utils.cpp_extension import load, _import_module_from_library
from tilelang import env
@@ -15,11 +16,10 @@
def _get_cached_lib():
name = 'compress_lib'
- cached_path = os.path.join(_CACHE_DIR, f"{name}.so")
- if os.path.exists(cached_path):
+ if os.path.exists(os.path.join(_CACHE_DIR, f"{name}.so")):
try:
- return _import_module_from_library(name, cached_path)
+ return _import_module_from_library(name, _CACHE_DIR, is_python_module=True)
except Exception:
# If loading fails, recompile
pass
@@ -88,7 +88,18 @@ def compress(A: torch.Tensor,
if compute_version >= (9, 0):
return compress_sm90(A, transposed=transposed, **kwargs)
elif compute_version >= (8, 0):
- return compress_sm80(A, transposed=transposed)
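+        # Handle transposition here so compress_sm80 always sees the row-major,
+        # non-transposed layout. fp8 tensors are compressed through an int8 view,
+        # where -0.0 (bit pattern 0x80) would read as a nonzero element, so
+        # negative zeros are normalized first and the original dtype is restored
+        # on the compressed output.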
+ if transposed:
+ A = A.t().contiguous()
+ origin_dtype = A.dtype
+ if is_float8_dtype(origin_dtype):
+ fp8_remove_negative_zeros_(A)
+ A = A.view(torch.int8)
+ A_sp, E = compress_sm80(A, transposed=False)
+ if is_float8_dtype(origin_dtype):
+ A_sp = A_sp.view(origin_dtype)
+ if transposed:
+ A_sp = A_sp.t().contiguous()
+ return A_sp, E
else:
raise ValueError(f"Unsupported CUDA compute version: {compute_version}. "
"Supported versions are sm_80 and sm_90.")
@@ -105,6 +116,8 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp
transposed (bool): If True, returns a transposed tensor of shape (K, M)
"""
elem, group = 2, 4
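+    # Default 2:4 pattern (two nonzeros per group of four); 32-bit (tf32) sparse
+    # MMAs use 1:2 granularity, so float32 inputs keep one nonzero per pair.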
+ if dtype == torch.float32:
+ elem, group = 1, 2
tensor = torch.randn((M, K), dtype=torch.float, device=device).view(M, -1, group)
indice = tensor.topk(elem, dim=-1).indices
tensor.scatter_(-1, indice, 0)
@@ -114,6 +127,36 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp
return tensor.to(dtype) # dtype like float8 might not have randn kernel
+def randint_semi_sparse(M: int,
+ K: int,
+ low: int,
+ high: int,
+ dtype=torch.int32,
+ device='cuda',
+ transposed: bool = False):
+ """
+    Generate a random semi-sparse integer tensor with 2:4 sparsity along the K dimension (1:2 when dtype is torch.float32).
+ Args:
+ M (int): Number of rows
+ K (int): Number of columns
+        low (int): Lower bound (inclusive) of the random integers
+        high (int): Upper bound (exclusive) of the random integers
+ dtype: Data type of the tensor
+ device: Device to create the tensor on
+ transposed (bool): If True, returns a transposed tensor of shape (K, M)
+ """
+ elem, group = 2, 4
+ if dtype == torch.float32:
+ elem, group = 1, 2
+ tensor = torch.randint(low, high, (M, K), dtype=dtype, device=device).view(M, -1, group)
+ indice = tensor.topk(elem, dim=-1).indices
+ tensor.scatter_(-1, indice, 0)
+ tensor = tensor.view(M, K)
+ if transposed:
+ tensor = tensor.t().contiguous()
+ return tensor
+
+
def arange_semi_sparse(M: int,
K: int,
dtype=torch.float16,
@@ -129,6 +172,8 @@ def arange_semi_sparse(M: int,
transposed (bool): If True, returns a transposed tensor of shape (K, M)
"""
elem, group = 2, 4
+ if dtype == torch.float32:
+ elem, group = 1, 2
tensor = torch.arange(M * K, dtype=dtype, device=device).view(M, -1, group)
indice = tensor.topk(elem, dim=-1).indices
tensor.scatter_(-1, indice, 0)
diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py
index b275708c4..b2905fb1b 100644
--- a/tilelang/utils/tensor.py
+++ b/tilelang/utils/tensor.py
@@ -5,6 +5,22 @@
import numpy as np
+def is_float8_dtype(dtype: torch.dtype) -> bool:
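+    """Return True if ``dtype`` is one of torch's float8 dtypes."""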
+ return dtype in {
+ torch.float8_e5m2,
+ torch.float8_e5m2fnuz,
+ torch.float8_e4m3fn,
+ torch.float8_e4m3fnuz,
+ }
+
+
+def fp8_remove_negative_zeros_(tensor: torch.Tensor):
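+    """Rewrite fp8 negative zeros to +0.0 in place.
+
+    ``-0.0`` compares equal to zero but carries the 0x80 bit pattern, which
+    reads as nonzero under an integer view of the storage, so bit-level
+    consumers (e.g. the sparse compressor) would misclassify it.
+    """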
+ assert is_float8_dtype(tensor.dtype), "Input tensor must be of float8 dtype"
+ bits = tensor.view(torch.uint8)
+ zeros_mask = (tensor == 0)
+ bits[zeros_mask] = 0x00
+
+
class TensorSupplyType(Enum):
Integer = 1
Uniform = 2