Skip to content

Commit 8d21f85

Browse files
jananisriramfacebook-github-bot
authored andcommitted
Construct APIs to modify Triton configs
Summary: Define set of preliminary APIs which translate Triton configs into TileIR configs: - Generate exhaustive set of TileIR configs, given no existing Triton configs - Convert Triton configs into an exhaustive set of TileIR configs, given any existing TileIR-related restrictions - Prune duplicate configs, i.e. those with all the same parameters except for `num_stages` and `num_warps` Differential Revision: D86124990
1 parent dbd8da1 commit 8d21f85

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

tritonbench/utils/tileir_utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""TileIR utils for TritonBench."""
2+
3+
import itertools
4+
5+
import triton
6+
7+
8+
def generate_exhaustive_tileir_configs():
9+
return [
10+
triton.Config({"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K, "occupancy": occ}, num_ctas=num_ctas)
11+
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
12+
[16, 32, 64, 128, 256], repeat=3
13+
)
14+
for occ in [1, 2]
15+
for num_ctas in [1, 2]
16+
]
17+
18+
19+
def convert_triton_to_tileir_configs(triton_configs) -> list[triton.Config]:
20+
tileir_configs = set()
21+
for config in triton_configs:
22+
if "occupancy" not in config.kwargs:
23+
for occ in [1, 2]:
24+
tileir_configs.add(
25+
triton.Config(
26+
{**config.kwargs, "occupancy": occ},
27+
num_warps=config.num_warps,
28+
num_stages=config.num_stages,
29+
num_ctas=config.num_ctas,
30+
maxnreg=config.maxnreg,
31+
pre_hook=config.pre_hook,
32+
ir_override=config.ir_override
33+
)
34+
)
35+
36+
for config in tileir_configs:
37+
for num_ctas in [1, 2]:
38+
tileir_configs.add(
39+
triton.Config(
40+
config.kwargs,
41+
num_warps=config.num_warps,
42+
num_stages=config.num_stages,
43+
num_ctas=num_ctas,
44+
maxnreg=config.maxnreg,
45+
pre_hook=config.pre_hook,
46+
ir_override=config.ir_override
47+
)
48+
)
49+
50+
return list(tileir_configs)
51+
52+
53+
def prune_duplicate_configs(configs, named_args, **kwargs) -> list[triton.Config]:
54+
"""
55+
Prune duplicate configs, i.e. those with all the same parameters except for num_warps and num_stages.
56+
"""
57+
pruned_configs = set()
58+
for config in configs:
59+
config.num_warps = None
60+
config.num_stages = None
61+
pruned_configs.add(config)
62+
return list(pruned_configs)

0 commit comments

Comments
 (0)