Skip to content

Commit 7239d15

Browse files
bchetiouiGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Remove TraversalOrder from traverse_op.
We only use it in one place now, and therefore no longer need this parametrization. PiperOrigin-RevId: 826556517
1 parent dc7d56e commit 7239d15

File tree

2 files changed

+22
-39
lines changed

2 files changed

+22
-39
lines changed

jax/experimental/mosaic/gpu/inference_utils.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@
1414

1515
"""Layout & transform inference convenience utils."""
1616

17-
from collections.abc import Callable, Sequence
18-
import enum
17+
from collections.abc import Sequence
1918
from functools import partial
2019
from typing import cast, Union
2120

22-
from jax._src.lib import mosaic_gpu_dialect as mgpu
2321
from jax._src.lib.mlir import ir
2422

2523
from . import fragmented_array as fa
@@ -303,37 +301,3 @@ def is_mma_layout(layout: fa.FragmentedLayout) -> bool:
303301
return columns % 16 == 0 and (
304302
layout == tcgen05.fa_m64_collective_layout(columns)
305303
)
306-
307-
308-
class TraversalOrder(enum.Enum):
309-
"""Traversal orders with respect to the data flow for IR."""
310-
311-
FORWARD = 1
312-
BACKWARDS = 2
313-
314-
315-
def traverse_op(
316-
op: ir.OpView,
317-
callback: Callable[[ir.OpView], None],
318-
traversal_order: TraversalOrder = TraversalOrder.FORWARD,
319-
do_not_recurse_into_ops: tuple[type, ...] = (mgpu.CustomPrimitiveOp,),
320-
):
321-
"""Traverses the operation and applies the callback in the given order.
322-
323-
If do_not_recurse_into_ops is provided, the callback will be executed on these
324-
ops, but any regions they might have will not be traversed.
325-
"""
326-
callback(op)
327-
if not isinstance(op, do_not_recurse_into_ops):
328-
# The block of a mosaic_gpu.custom_primitive op is already lowered so it
329-
# should not be traversed.
330-
for region in op.operation.regions:
331-
for block in region:
332-
if traversal_order == TraversalOrder.FORWARD:
333-
ops_to_traverse = list(block)
334-
else:
335-
ops_to_traverse = reversed(list(block)) # type: ignore
336-
for block_op in ops_to_traverse:
337-
traverse_op(
338-
block_op, callback, traversal_order, do_not_recurse_into_ops
339-
)

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,25 @@ def is_terminator(op: ir.OpView) -> bool:
18501850
return isinstance(op, (scf.YieldOp, scf.ConditionOp))
18511851

18521852

1853+
def traverse_op(
1854+
op: ir.OpView,
1855+
callback: Callable[[ir.OpView], None],
1856+
):
1857+
"""Traverses the operation and applies the callback in pre-order fashion.
1858+
1859+
Skips recursing into `mgpu.CustomPrimitiveOp`s, and assumes that the values
1860+
iterated on are not being modified.
1861+
"""
1862+
callback(op)
1863+
# The block of a mosaic_gpu.custom_primitive op is already lowered so it
1864+
# should not be traversed.
1865+
if not isinstance(op, mgpu.CustomPrimitiveOp):
1866+
for region in op.operation.regions:
1867+
for block in region:
1868+
for block_op in block.operations:
1869+
traverse_op(block_op, callback)
1870+
1871+
18531872
def infer_layout(module: ir.Module):
18541873
"""Infers layouts for the given module.
18551874
@@ -1891,7 +1910,7 @@ def gather_equations(op: ir.Operation):
18911910
hints.extend(op_hints)
18921911

18931912
for op in module.body:
1894-
inference_utils.traverse_op(op, gather_equations)
1913+
traverse_op(op, gather_equations)
18951914

18961915
if isinstance(global_equation_system, eqns.Unsatisfiable):
18971916
raise ValueError(
@@ -1933,4 +1952,4 @@ def gather_equations(op: ir.Operation):
19331952

19341953
# Sanity check: ensure that all ops have the right number of in/out layouts.
19351954
for op in module.body:
1936-
inference_utils.traverse_op(op, _ensure_all_layouts_are_set)
1955+
traverse_op(op, _ensure_all_layouts_are_set)

0 commit comments

Comments
 (0)