Skip to content

Commit 8d7d7a1

Browse files
committed
[wip] Support hl.barrier for mega-kernels
stack-info: PR: #1151, branch: jansel/stack/228
1 parent b074c9e commit 8d7d7a1

22 files changed

+1302
-65
lines changed

examples/split_k_barrier.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
5+
import helion
6+
from helion._testing import DEVICE
7+
from helion._testing import run_example
8+
from helion.autotuner import PowerOfTwoFragment
9+
import helion.language as hl
10+
11+
12+
@helion.kernel(static_shapes=True, dot_precision="ieee")
def split_k_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Two-stage split-K matmul using hl.barrier(). The barrier approach
    gives deterministic results as opposed to the atomic_add approach.

    Stage 1:
      - Split K into at most `split_k` contiguous chunks of `block_k` elements.
      - Each chunk computes a partial [tile_m, tile_n] product into its own slice of `tmp`.

    Barrier:
      - Grid-wide barrier to ensure all partials are written before reduction.

    Stage 2:
      - Reduce partials across the split dimension and write `out`.

    Shapes:
      a: [M, K]
      b: [K, N]
      tmp: [M, N, split_k]
      out: [M, N]

    Notes:
      - Static shapes keep codegen simpler.
      - `split_k` is registered as a tunable via
        `PowerOfTwoFragment(16, 512, 64)`, so the autotuner may choose it.
      - `block_k` is rounded up to a power of two, so the K loop can produce
        fewer than `split_k` chunks. `tmp` is therefore zero-initialized:
        slices that no chunk writes contribute 0 to the stage-2 sum instead
        of uninitialized memory.
    """
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, f"size mismatch {k} != {k2}"
    split_k = hl.register_tunable("split_k", PowerOfTwoFragment(16, 512, 64))
    block_k = helion.next_power_of_2(helion.cdiv(k, split_k))
    # zeros, not empty: cdiv(k, block_k) may be smaller than split_k (see Notes).
    tmp = torch.zeros((m, n, split_k), device=a.device, dtype=a.dtype)
    out = torch.empty((m, n), device=a.device, dtype=a.dtype)

    for tile_m, tile_n, tile_k_outer in hl.tile(
        [m, n, k], block_size=[None, None, block_k]
    ):
        acc = hl.zeros([tile_m, tile_n], device=a.device, dtype=a.dtype)
        for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end):
            acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n])
        # this could be a hl.atomic_add to avoid the barrier, but that would be non-deterministic
        tmp[tile_m, tile_n, tile_k_outer.id] = acc

    hl.barrier()

    for tile_m, tile_n in hl.tile([m, n]):
        out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1)

    return out
60+
61+
62+
def check(m: int, k: int, n: int) -> None:
    """Run split_k_matmul against torch.matmul on random inputs.

    Args:
        m: Number of rows of the left operand.
        k: Shared (reduction) dimension.
        n: Number of columns of the right operand.
    """
    lhs = torch.randn(m, k, device=DEVICE)
    # Allocated as [n, k] and transposed, so the right operand is a [k, n] view.
    rhs = torch.randn(n, k, device=DEVICE).T

    run_example(
        split_k_matmul,
        torch.matmul,
        args=(lhs, rhs),
        atol=5e-1,  # long reduction accumulate errors
    )
72+
73+
74+
def main() -> None:
    """Seed torch's RNG and run a single check with m=16, k=4096, n=16."""
    torch.manual_seed(0)
    check(m=16, k=4096, n=16)
77+
78+
79+
if __name__ == "__main__":
80+
main()

helion/_compiler/compile_environment.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from .. import exc
2323
from ..language.constexpr import ConstExpr
24-
from .loop_dependency_checker import LoopDependencyChecker
2524
from .source_location import SourceLocation
2625
from .source_location import current_location
2726
from .variable_origin import BlockSizeOrigin
@@ -102,18 +101,18 @@ def __init__(
102101
self.block_sizes: list[BlockSizeInfo] = []
103102
self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
104103
self.config_spec = ConfigSpec()
105-
if settings.autotune_force_persistent:
106-
for pid_type in ("flat", "xyz"):
107-
self.config_spec.disallow_pid_type(pid_type)
108104
self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
109105
collections.Counter()
110106
)
111107
self.specialized_vars: set[sympy.Symbol] = set()
112-
self.loop_dependency_checker = LoopDependencyChecker()
113108
self._symint_cache: dict[object, torch.SymInt] = {}
114109
self.device_load_count = (
115110
0 # Track number of loads in all device code for eviction policy tuning
116111
)
112+
if settings.autotune_force_persistent:
113+
for pid_type in ("flat", "xyz"):
114+
self.config_spec.disallow_pid_type(pid_type)
115+
self.has_barrier: bool = False
117116

118117
def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
119118
from .device_function import contains_only_block_size_symbols
@@ -405,7 +404,6 @@ def __enter__(self) -> Self:
405404
assert getattr(tls, "env", None) is None, "CompileEnvironment already active"
406405
self.fake_mode.__enter__()
407406
tls.env = self
408-
self.loop_dependency_checker = LoopDependencyChecker()
409407
return self
410408

411409
def __exit__(

helion/_compiler/device_function.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,12 @@ def __init__(self, val: int) -> None:
197197

198198

199199
class DeviceFunction:
200-
def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
200+
def __init__(
201+
self,
202+
name: str,
203+
config: Config,
204+
codegen: GenerateAST,
205+
) -> None:
201206
super().__init__()
202207
self.name = name
203208
self.config = config
@@ -659,6 +664,11 @@ def codegen_function_call(self) -> ast.AST:
659664
[
660665
f"num_warps={num_warps}",
661666
f"num_stages={self.config.num_stages}",
667+
*(
668+
["launch_cooperative_grid=True"]
669+
if CompileEnvironment.current().has_barrier
670+
else []
671+
),
662672
]
663673
+ [
664674
f"{x.removeprefix('_triton_config_')}={self.config[x]}"

helion/_compiler/device_ir.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from .inductor_lowering import CodegenState
4949
from .inductor_lowering import codegen_call_with_graph
5050
from .inductor_lowering import prepare_graph_lowerings
51+
from .loop_dependency_checker import LoopDependencyChecker
5152
from .matmul_utils import tensor_matmul_replacement
5253
from .matmul_utils import torch_matmul_replacement
5354
from .node_masking import remove_unnecessary_masking
@@ -189,6 +190,8 @@ def codegen(self, state: CodegenState) -> list[object]:
189190

190191

191192
class RootGraphInfo(GraphInfo):
193+
phase_index: int = 0
194+
192195
@property
193196
def name(self) -> str:
194197
return f"root_graph_{self.graph_id}"
@@ -376,12 +379,22 @@ class RolledReductionInfo(NamedTuple):
376379
can_be_rolled_by_caller: bool
377380

378381

382+
@dataclasses.dataclass
class KernelPhase:
    """Top-level device loops grouped into one phase.

    WalkHostAST closes a phase whenever it meets a barrier statement, so the
    loops of one phase are not separated by a barrier.
    """

    # Root indices, i.e. positions into DeviceIR.root_ids (not graph ids).
    roots: list[int]
    # Host AST `for` node of each entry in `roots`, in the same order.
    root_nodes: list[ast.For]
    # Each phase gets its own checker instance by default.
    loop_dependency_checker: LoopDependencyChecker = dataclasses.field(
        default_factory=LoopDependencyChecker,
    )
389+
390+
379391
class DeviceIR:
380392
def __init__(self) -> None:
    """Create a device IR with every container empty."""
    super().__init__()
    # All registered graphs; root_ids holds indices into this list.
    self.graphs: list[GraphInfo] = []
    self.root_ids: list[int] = []
    self.rolled_reductions: list[RolledReductionInfo] = []
    # Barrier-delimited groups of root loops; assigned by lower_to_device_ir.
    self.phases: list[KernelPhase] = []
    self.grid_block_ids: list[list[int]] = []
386399

387400
def get_root(self, config: Config, graph_id: int) -> torch.fx.Graph:
@@ -435,6 +448,11 @@ def add_reduction_loop_graph(
435448
def add_root_graph(self, graph: torch.fx.Graph) -> None:
    """Register `graph` as a RootGraphInfo graph and record its id in root_ids."""
    new_graph_id = self.add_graph(graph, graph_info_cls=RootGraphInfo)
    self.root_ids.append(new_graph_id)
437450

451+
def phase_for_root(self, root_id: int) -> int:
    """Return the phase index stored on a root graph.

    Args:
        root_id: Position of the root in `self.root_ids` (a root index,
            not a graph id).

    Returns:
        The `phase_index` attribute of the corresponding RootGraphInfo.
    """
    graph_id = self.root_ids[root_id]
    info = self.graphs[graph_id]
    assert isinstance(info, RootGraphInfo)
    return info.phase_index
455+
438456
def build_rolled_reductions(self) -> None:
439457
env = CompileEnvironment.current()
440458
rdims = [bs for bs in env.block_sizes if bs.reduction]
@@ -1274,6 +1292,10 @@ class WalkHostAST(NodeVisitor):
12741292
def __init__(self, device_ir: DeviceIR) -> None:
    """Set up the visitor with empty phase bookkeeping.

    Args:
        device_ir: The DeviceIR this walk records its results into.
    """
    super().__init__()
    self.device_ir = device_ir
    # Number of top-level device loops recorded so far.
    self.root_index = 0
    # Root indices seen since the last barrier (or since the start).
    self.current_phase_roots: list[int] = []
    # Phases already closed by a barrier or by flush_phases().
    self.phases: list[KernelPhase] = []
    # Host `for` node of every recorded root, indexed by root index.
    self.root_nodes: list[ast.For] = []
12771299

12781300
def visit_For(self, node: ast.For) -> None:
12791301
assert isinstance(node, ExtendedAST)
@@ -1292,9 +1314,44 @@ def visit_For(self, node: ast.For) -> None:
12921314
# pyrefly: ignore [missing-attribute]
12931315
block_ids = [inner.block_id]
12941316
self.device_ir.grid_block_ids.append(block_ids)
1317+
# store root index (position) not graph id
1318+
self.root_nodes.append(node)
1319+
self.current_phase_roots.append(len(self.device_ir.root_ids) - 1)
1320+
self.root_index += 1
12951321
else:
12961322
self.generic_visit(node)
12971323

1324+
def visit_Expr(self, node: ast.Expr) -> None:
    """Record barrier placement between top-level loops.

    An expression statement whose value has a BarrierResultType type closes
    the current phase: the roots collected since the previous barrier become
    one KernelPhase. Any other expression statement is visited generically.

    Raises:
        exc.BarrierOnlyAllowedAtTopLevel: If no root loop has been recorded
            since the start of the function or since the previous barrier.
    """
    # Function-scope import, as in the original code.
    # NOTE(review): presumably avoids an import cycle - confirm.
    from .type_propagation import BarrierResultType

    assert isinstance(node, ExtendedAST)
    assert isinstance(node.value, ExtendedAST)

    if not isinstance(node.value._type_info, BarrierResultType):
        self.generic_visit(node)
        return

    if self.root_index == 0 or not self.current_phase_roots:
        raise exc.BarrierOnlyAllowedAtTopLevel
    # current_phase_roots is non-empty here, so this always appends exactly
    # one KernelPhase and resets the pending list.
    self.flush_phases()
1344+
1345+
def flush_phases(self) -> None:
    """Close the pending phase, if any.

    When roots have been collected since the last barrier, append them as a
    KernelPhase to `self.phases` and start a new empty pending list.
    Does nothing when no roots are pending.
    """
    pending = self.current_phase_roots
    if not pending:
        return
    pending_nodes = [self.root_nodes[idx] for idx in pending]
    self.phases.append(KernelPhase(roots=pending, root_nodes=pending_nodes))
    self.current_phase_roots = []
1354+
12981355

12991356
def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int, int]:
13001357
"""Count the number of load and store operations in device code for autotuning.
@@ -1386,6 +1443,18 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
13861443
visitor = WalkHostAST(device_ir)
13871444
for stmt in func.body:
13881445
visitor.visit(stmt)
1446+
visitor.flush_phases()
1447+
device_ir.phases = visitor.phases
1448+
# Run dependency checks once, per phase, so codegen does not redo it per-config.
1449+
for phase in device_ir.phases:
1450+
checker = phase.loop_dependency_checker
1451+
for loop_node in phase.root_nodes:
1452+
checker.register_loop(loop_node)
1453+
for phase_idx, phase in enumerate(device_ir.phases):
1454+
for ridx in phase.roots:
1455+
graph_info = device_ir.graphs[device_ir.root_ids[ridx]]
1456+
assert isinstance(graph_info, RootGraphInfo)
1457+
graph_info.phase_index = phase_idx
13891458
# If there are no top-level device loops, we cannot generate a valid kernel.
13901459
# Raise a friendly error instead of emitting an empty Triton function body.
13911460
if len(device_ir.root_ids) == 0:

helion/_compiler/generate_ast.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .. import exc
1212
from ..language._decorators import is_api_func
13+
from ..runtime.config import Config
1314
from .ast_extension import ExtendedAST
1415
from .ast_extension import LoopType
1516
from .ast_extension import NodeVisitor
@@ -24,6 +25,7 @@
2425
from .helper_function import CodegenInterface
2526
from .inductor_lowering import CodegenState
2627
from .inductor_lowering import codegen_call_with_graph
28+
from .loop_dependency_checker import LoopDependencyChecker
2729
from .program_id import ForEachProgramID
2830
from .tile_strategy import DeviceLoopState
2931
from .variable_origin import ArgumentOrigin
@@ -35,6 +37,7 @@
3537

3638
from ..runtime import Config
3739
from .host_function import HostFunction
40+
from .loop_dependency_checker import LoopDependencyChecker
3841
from .tile_strategy import DeviceLoopOrGridState
3942
from .type_propagation import TensorType
4043

@@ -55,7 +58,11 @@ def __init__(self, func: HostFunction, config: Config) -> None:
5558
self.next_else_block: list[ast.AST] | None = None
5659

5760
# Now create device function and initialize CodegenInterface
58-
self.device_function = DeviceFunction(f"_helion_{func.name}", config, self)
61+
self.device_function = DeviceFunction(
62+
f"_helion_{func.name}",
63+
config,
64+
self,
65+
)
5966
CodegenInterface.__init__(self, self.device_function)
6067

6168
def offset_var(self, block_idx: int) -> str:
@@ -69,6 +76,10 @@ def mask_var(self, block_idx: int) -> str | None:
6976
return loops[-1].strategy.mask_var(block_idx)
7077
return None
7178

79+
def _phase_checker(self, root_id: int) -> LoopDependencyChecker:
    """Return the loop dependency checker of the phase that holds `root_id`.

    The phase is looked up through `device_ir.phase_for_root(root_id)`.
    """
    device_ir = self.host_function.device_ir
    phase = device_ir.phases[device_ir.phase_for_root(root_id)]
    return phase.loop_dependency_checker
82+
7283
def add_statement(self, stmt: ast.AST | str | None) -> None:
7384
if stmt is None:
7485
return
@@ -226,17 +237,20 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
226237
if node._loop_type == LoopType.GRID:
227238
assert not node.orelse
228239

240+
assert node._root_id is not None
241+
# Loop dependency checks were already run during lowering; phase checker kept for symmetry/debug.
242+
self._phase_checker(node._root_id)
243+
229244
if len(self.host_function.device_ir.root_ids) == 1:
230245
body = self.device_function.body
231246
else:
232247
assert len(self.host_function.device_ir.root_ids) > 1
233-
assert node._root_id is not None
234248
# Multiple top level for loops
235249

236250
if node._root_id == 0:
237251
self.device_function.set_pid(
238252
ForEachProgramID(
239-
self.device_function.new_var("pid_shared", dce=False)
253+
self.device_function.new_var("pid_shared", dce=False),
240254
)
241255
)
242256
self.device_function.body.extend(
@@ -309,6 +323,11 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
309323
# This ensures block size and rdim vars are defined in the correct order
310324
self.device_function.flush_deferred_rdim_defs(self)
311325

326+
if isinstance(self.device_function.pid, ForEachProgramID):
327+
self.device_function.pid.case_phases.append(
328+
self.host_function.device_ir.phase_for_root(node._root_id)
329+
)
330+
312331
# If we are in a multi top level loop, for all loops except for the last one
313332
# emit ifthenelse blocks
314333
if node._root_id < len(self.host_function.device_ir.root_ids) - 1:
@@ -476,6 +495,9 @@ def generate_ast(
476495
func: HostFunction, config: Config, emit_repro_caller: bool
477496
) -> ast.AST:
478497
with func:
498+
if len(func.device_ir.phases) > 1:
499+
if not str(config.pid_type).startswith("persistent"):
500+
raise exc.BarrierRequiresPersistent(config.pid_type)
479501
codegen = GenerateAST(func, config)
480502
with codegen.device_function:
481503
for stmt in func.body:

helion/_compiler/loop_dependency_checker.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,29 @@ class LoopDependencyChecker:
2222
def __init__(self) -> None:
    """Start with no recorded accesses, no pending barriers, checking enabled."""
    # Names read / written by the loops registered so far.
    self.reads: set[str] = set()
    self.writes: set[str] = set()
    # Root ids that are directly followed by a barrier.
    self._barrier_after_root: set[int] = set()
    # Root id used by register_loop() when the caller passes none.
    self._root_counter: int = 0
    # When True, register_loop() returns without doing anything.
    self.disabled: bool = False
28+
29+
def insert_barrier_after_root(self, root_id: int) -> None:
    """Record that a barrier separates root_id and root_id+1.

    register_loop() consults this set: when the loop it registers directly
    follows a recorded barrier, the accumulated reads and writes are cleared
    before the new loop is checked.
    """
    self._barrier_after_root.update((root_id,))
32+
33+
def register_loop(self, loop_node: ast.For, root_id: int | None = None) -> None:
    """Check a top-level loop against earlier loops, then record its accesses.

    Args:
        loop_node: The loop whose body is analysed with ReadWrites.from_list.
        root_id: Root id of the loop. When None, the internal counter
            (one past the previously registered root) is used.

    Does nothing when `self.disabled` is set.
    """
    if self.disabled:
        return

    current_root = self._root_counter if root_id is None else root_id
    previous_root = current_root - 1
    if previous_root in self._barrier_after_root:
        # A barrier sits between the previous root and this one, so the
        # accesses accumulated so far are dropped before checking.
        self.reads.clear()
        self.writes.clear()
        self._barrier_after_root.discard(previous_root)

    accesses = ReadWrites.from_list(loop_node.body)
    self._check_dependencies(accesses)

    self.reads.update(accesses.reads)
    self.writes.update(accesses.writes)
    self._root_counter = current_root + 1
3348

3449
def _check_dependencies(self, rw: ReadWrites) -> None:
3550
"""

0 commit comments

Comments
 (0)