|
| 1 | +"""TileIR utils for TritonBench.""" |
| 2 | + |
| 3 | +import itertools |
| 4 | + |
| 5 | +import triton |
| 6 | + |
| 7 | + |
| 8 | +def generate_exhaustive_tileir_configs(): |
| 9 | + return [ |
| 10 | + triton.Config({"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K, "occupancy": occ}, num_ctas=num_ctas) |
| 11 | + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( |
| 12 | + [16, 32, 64, 128, 256], repeat=3 |
| 13 | + ) |
| 14 | + for occ in [1, 2] |
| 15 | + for num_ctas in [1, 2] |
| 16 | + ] |
| 17 | + |
| 18 | + |
| 19 | +def convert_triton_to_tileir_configs(triton_configs) -> list[triton.Config]: |
| 20 | + tileir_configs = set() |
| 21 | + for config in triton_configs: |
| 22 | + if "occupancy" not in config.kwargs: |
| 23 | + for occ in [1, 2]: |
| 24 | + tileir_configs.add( |
| 25 | + triton.Config( |
| 26 | + {**config.kwargs, "occupancy": occ}, |
| 27 | + num_warps=config.num_warps, |
| 28 | + num_stages=config.num_stages, |
| 29 | + num_ctas=config.num_ctas, |
| 30 | + maxnreg=config.maxnreg, |
| 31 | + pre_hook=config.pre_hook, |
| 32 | + ir_override=config.ir_override |
| 33 | + ) |
| 34 | + ) |
| 35 | + |
| 36 | + for config in tileir_configs: |
| 37 | + for num_ctas in [1, 2]: |
| 38 | + tileir_configs.add( |
| 39 | + triton.Config( |
| 40 | + config.kwargs, |
| 41 | + num_warps=config.num_warps, |
| 42 | + num_stages=config.num_stages, |
| 43 | + num_ctas=num_ctas, |
| 44 | + maxnreg=config.maxnreg, |
| 45 | + pre_hook=config.pre_hook, |
| 46 | + ir_override=config.ir_override |
| 47 | + ) |
| 48 | + ) |
| 49 | + |
| 50 | + return list(tileir_configs) |
| 51 | + |
| 52 | + |
| 53 | +def prune_duplicate_configs(configs, named_args, **kwargs) -> list[triton.Config]: |
| 54 | + """ |
| 55 | + Prune duplicate configs, i.e. those with all the same parameters except for num_warps and num_stages. |
| 56 | + """ |
| 57 | + pruned_configs = set() |
| 58 | + for config in configs: |
| 59 | + config.num_warps = None |
| 60 | + config.num_stages = None |
| 61 | + pruned_configs.add(config) |
| 62 | + return list(pruned_configs) |
0 commit comments