4848from .inductor_lowering import CodegenState
4949from .inductor_lowering import codegen_call_with_graph
5050from .inductor_lowering import prepare_graph_lowerings
51+ from .loop_dependency_checker import LoopDependencyChecker
5152from .matmul_utils import tensor_matmul_replacement
5253from .matmul_utils import torch_matmul_replacement
5354from .node_masking import remove_unnecessary_masking
@@ -189,6 +190,8 @@ def codegen(self, state: CodegenState) -> list[object]:
189190
190191
191192class RootGraphInfo (GraphInfo ):
193+ phase_index : int = 0
194+
192195 @property
193196 def name (self ) -> str :
194197 return f"root_graph_{ self .graph_id } "
@@ -376,12 +379,22 @@ class RolledReductionInfo(NamedTuple):
376379 can_be_rolled_by_caller : bool
377380
378381
382+ @dataclasses .dataclass
383+ class KernelPhase :
384+ roots : list [int ] # store root indices
385+ root_nodes : list [ast .For ]
386+ loop_dependency_checker : LoopDependencyChecker = dataclasses .field (
387+ default_factory = LoopDependencyChecker
388+ )
389+
390+
379391class DeviceIR :
380392 def __init__ (self ) -> None :
381393 super ().__init__ ()
382394 self .graphs : list [GraphInfo ] = []
383395 self .root_ids : list [int ] = []
384396 self .rolled_reductions : list [RolledReductionInfo ] = []
397+ self .phases : list [KernelPhase ] = []
385398 self .grid_block_ids : list [list [int ]] = []
386399
387400 def get_root (self , config : Config , graph_id : int ) -> torch .fx .Graph :
@@ -435,6 +448,11 @@ def add_reduction_loop_graph(
435448 def add_root_graph (self , graph : torch .fx .Graph ) -> None :
436449 self .root_ids .append (self .add_graph (graph , graph_info_cls = RootGraphInfo ))
437450
451+ def phase_for_root (self , root_id : int ) -> int :
452+ graph_info = self .graphs [self .root_ids [root_id ]]
453+ assert isinstance (graph_info , RootGraphInfo )
454+ return graph_info .phase_index
455+
438456 def build_rolled_reductions (self ) -> None :
439457 env = CompileEnvironment .current ()
440458 rdims = [bs for bs in env .block_sizes if bs .reduction ]
@@ -1274,6 +1292,10 @@ class WalkHostAST(NodeVisitor):
12741292 def __init__ (self , device_ir : DeviceIR ) -> None :
12751293 super ().__init__ ()
12761294 self .device_ir = device_ir
1295+ self .root_index = 0
1296+ self .current_phase_roots : list [int ] = []
1297+ self .phases : list [KernelPhase ] = []
1298+ self .root_nodes : list [ast .For ] = []
12771299
12781300 def visit_For (self , node : ast .For ) -> None :
12791301 assert isinstance (node , ExtendedAST )
@@ -1292,9 +1314,44 @@ def visit_For(self, node: ast.For) -> None:
12921314 # pyrefly: ignore [missing-attribute]
12931315 block_ids = [inner .block_id ]
12941316 self .device_ir .grid_block_ids .append (block_ids )
1317+ # store root index (position) not graph id
1318+ self .root_nodes .append (node )
1319+ self .current_phase_roots .append (len (self .device_ir .root_ids ) - 1 )
1320+ self .root_index += 1
12951321 else :
12961322 self .generic_visit (node )
12971323
1324+ def visit_Expr (self , node : ast .Expr ) -> None :
1325+ # Record barrier placement between top-level loops.
1326+ from .type_propagation import BarrierResultType
1327+
1328+ assert isinstance (node , ExtendedAST )
1329+ assert isinstance (node .value , ExtendedAST )
1330+ is_barrier = isinstance (node .value ._type_info , BarrierResultType )
1331+
1332+ if is_barrier :
1333+ if self .root_index == 0 or not self .current_phase_roots :
1334+ raise exc .BarrierOnlyAllowedAtTopLevel
1335+ self .phases .append (
1336+ KernelPhase (
1337+ roots = self .current_phase_roots ,
1338+ root_nodes = [self .root_nodes [r ] for r in self .current_phase_roots ],
1339+ )
1340+ )
1341+ self .current_phase_roots = []
1342+ return
1343+ self .generic_visit (node )
1344+
1345+ def flush_phases (self ) -> None :
1346+ if self .current_phase_roots :
1347+ self .phases .append (
1348+ KernelPhase (
1349+ roots = self .current_phase_roots ,
1350+ root_nodes = [self .root_nodes [r ] for r in self .current_phase_roots ],
1351+ )
1352+ )
1353+ self .current_phase_roots = []
1354+
12981355
12991356def _count_device_loads_and_stores (device_ir : DeviceIR ) -> tuple [int , int , int ]:
13001357 """Count the number of load and store operations in device code for autotuning.
@@ -1386,6 +1443,18 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
13861443 visitor = WalkHostAST (device_ir )
13871444 for stmt in func .body :
13881445 visitor .visit (stmt )
1446+ visitor .flush_phases ()
1447+ device_ir .phases = visitor .phases
1448+ # Run dependency checks once, per phase, so codegen does not redo it per-config.
1449+ for phase in device_ir .phases :
1450+ checker = phase .loop_dependency_checker
1451+ for loop_node in phase .root_nodes :
1452+ checker .register_loop (loop_node )
1453+ for phase_idx , phase in enumerate (device_ir .phases ):
1454+ for ridx in phase .roots :
1455+ graph_info = device_ir .graphs [device_ir .root_ids [ridx ]]
1456+ assert isinstance (graph_info , RootGraphInfo )
1457+ graph_info .phase_index = phase_idx
13891458 # If there are no top-level device loops, we cannot generate a valid kernel.
13901459 # Raise a friendly error instead of emitting an empty Triton function body.
13911460 if len (device_ir .root_ids ) == 0 :
0 commit comments