Skip to content

Commit 7fa0603

Browse files
[XLA:MGPU] Port Tiling to C++.
PiperOrigin-RevId: 834710515
1 parent 987a025 commit 7fa0603

File tree

13 files changed

+756
-215
lines changed

13 files changed

+756
-215
lines changed

jax/_src/lib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -145,6 +145,7 @@ def _xla_gc_callback(*args):
145145
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401
146146

147147
import jaxlib.mosaic.python.mosaic_gpu as mosaic_gpu_dialect # pytype: disable=import-error # noqa: F401
148+
import jaxlib.mosaic.python.mgpu_ext as mgpu_ext # pytype: disable=import-error # noqa: F401
148149
import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401
149150

150151
# TODO(rocm): check if we need the same for rocm.

jax/experimental/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -287,6 +287,7 @@ py_library_providing_imports_info(
287287
"//jaxlib/mlir:scf_dialect",
288288
"//jaxlib/mlir:vector_dialect",
289289
"//jaxlib/mosaic/python:gpu_dialect",
290+
"//jaxlib/mosaic/python:mgpu_ext",
290291
] + py_deps("absl-all") + py_deps("numpy"),
291292
)
292293

jax/experimental/mosaic/gpu/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515

1616
from jax import ShapeDtypeStruct as ShapeDtypeStruct
1717
from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401
18+
from jax._src.lib import mgpu_ext as ext # noqa: F401
1819

1920
# The imports below shadow the module, so we need to rename it.
2021
from . import wgmma as _wgmma # noqa: F401

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 1 addition & 212 deletions
Original file line number | Diff line number | Diff line change
@@ -46,219 +46,8 @@
4646
c = utils.c
4747

4848

49-
@dataclasses.dataclass(frozen=True)
50-
class Tiling:
51-
"""A tiling expression describing a permutation of elements of an nd-array.
52-
53-
To apply one level of tiling to an array, each of the trailing dimensions (up
54-
to the rank of the tile) is unfolded into two dimensions: first equal to the
55-
ratio of the dimension size and the tile size, and second equal to the tile
56-
size. Then, all newly unfolded minor dimensions are transposed to appear at
57-
the end.
58-
59-
This expression describes multi-level tiling, by applying each element of
60-
`tiles` in sequence to the array.
61-
62-
See https://openxla.org/xla/tiled_layout for a more detailed explanation.
63-
"""
64-
tiles: tuple[tuple[int, ...], ...]
65-
66-
def __post_init__(self):
67-
if not self.tiles:
68-
return
69-
last_tile_rank = len(self.tiles[0])
70-
for tile in self.tiles:
71-
if len(tile) > last_tile_rank:
72-
raise ValueError("Tiles must have a decreasing rank")
73-
if not tile:
74-
raise ValueError("Tiles must not be empty")
75-
if any(d <= 0 for d in tile):
76-
raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}")
77-
last_tile_rank = len(tile)
78-
79-
def __str__(self):
80-
return f"Tiling({''.join(map(str, self.tiles))})"
81-
82-
def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
83-
"""Computes the shape of an array after tiling."""
84-
orig_shape = shape
85-
def fail():
86-
raise ValueError(f"Tiling {self.tiles} does not apply to shape {orig_shape}")
87-
for tile in self.tiles:
88-
if len(tile) > len(shape):
89-
fail()
90-
untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):]
91-
if any(s % t != 0 for s, t in zip(tiled_dims, tile)):
92-
fail()
93-
shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile)
94-
return shape
95-
96-
def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
97-
"""Computes the shape of an array before tiling from its tiled shape."""
98-
orig_shape = shape
99-
def fail():
100-
raise ValueError(
101-
f"shape {orig_shape} is not a valid result of applying tiling {self}."
102-
)
103-
for tile in reversed(self.tiles):
104-
if len(tile) > len(shape):
105-
fail()
106-
untiled_dims = shape[:-2 * len(tile)]
107-
tiled_dims = shape[-2 * len(tile):-len(tile)]
108-
tiling_dims = shape[-len(tile):]
109-
if tiling_dims != tile:
110-
fail()
111-
shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile)))
112-
return shape
113-
114-
def canonicalize(self) -> Tiling:
115-
"""Returns a canonicalized version of the tiling.
116-
117-
We define a tiling to be canonical if, at each step (except the first one,
118-
which defines the base tile shape):
119-
120-
1. The tiling partitions at least one dimension in more than 1 tile. For
121-
example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it
122-
yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which
123-
allows getting rid of the unnecessary `1` dimensions.
124-
2. The leading dimensions of each tile are not `1`. If canonicalizing a
125-
tile in this way leads to an empty tile, then the tile is given shape
126-
`(1,)`---which is still a meaningful (final) tile. For example, the
127-
tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape
128-
`(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows
129-
getting rid of the unnecessary `1` dimension, and yields a shape
130-
`(8, 2, 4)`.
131-
"""
132-
if len(self.tiles) <= 1:
133-
return self
49+
Tiling = mgpu.ext.Tiling
13450

135-
shape = self.tiles[0]
136-
new_tiling = [self.tiles[0]]
137-
for tile in self.tiles[1:]:
138-
for i, d in enumerate(tile):
139-
if d != 1:
140-
canonical_tile = tile[i:]
141-
break
142-
else:
143-
canonical_tile = (1,)
144-
tiled_dims = shape[-len(canonical_tile):]
145-
if tiled_dims == canonical_tile:
146-
continue
147-
shape = canonical_tile
148-
new_tiling.append(canonical_tile)
149-
return Tiling(tuple(new_tiling))
150-
151-
def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]:
152-
"""Computes the strides of an array after tiling."""
153-
for tile in self.tiles:
154-
untiled, tiled = strides[:-len(tile)], strides[-len(tile):]
155-
strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled)
156-
return strides
157-
158-
def tile_dimension(self, dim: int) -> tuple[bool, ...]:
159-
"""Result is True whenever the tiled dim originated from the given input dim."""
160-
tiling_rank = len(self.tiles[0])
161-
if dim < 0 or dim >= tiling_rank:
162-
raise ValueError(f"Invalid dimension {dim} for tiling {self}")
163-
strides = [1] * tiling_rank
164-
strides[dim] = 0
165-
return tuple(s == 0 for s in self.tile_strides(tuple(strides)))
166-
167-
def remove_dimension(self, dim: int) -> Tiling:
168-
"""Returns a tiling with the given dimension removed."""
169-
tiling_rank = len(self.tiles[0])
170-
if dim < 0 or dim >= tiling_rank:
171-
raise ValueError(f"Invalid dimension {dim} for tiling {self}")
172-
dim_in_tile = dim
173-
tiles = []
174-
last_tile_rank = len(self.tiles[0])
175-
for t in self.tiles:
176-
assert last_tile_rank >= len(t)
177-
dim_in_tile -= last_tile_rank - len(t)
178-
last_tile_rank = len(t)
179-
if dim_in_tile >= 0:
180-
t = t[:dim_in_tile] + t[dim_in_tile + 1:]
181-
if not t: # If this tile is empty, all other tiles will be empty too.
182-
break
183-
tiles.append(t)
184-
return Tiling(tuple(tiles))
185-
186-
def tile_nested_shape_strides(
187-
self,
188-
shape: tuple[tuple[int, ...], ...],
189-
strides: tuple[tuple[int, ...], ...],
190-
) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]:
191-
"""A fused version of `tile_shape` and `tile_strides` for nested shapes.
192-
193-
By nested shape we mean that each logical dimension (i.e. each element of
194-
shape/strides) is actually composed out of multiple physical dimensions.
195-
For example, a row-major array of logical shape (128, 128) that is tiled
196-
into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each
197-
dim is split into two sub-dims) and nested strides of
198-
((2 * 64 * 64, 64), (64 * 64, 1)).
199-
"""
200-
if len(shape) != len(strides):
201-
raise ValueError(
202-
f"Shape {shape} and strides {strides} must have the same length"
203-
)
204-
def fail_if(cond, shape=shape): # Capture shape now.
205-
if cond:
206-
raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}")
207-
for tile in self.tiles:
208-
fail_if(len(tile) > len(shape))
209-
untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):]
210-
untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):]
211-
major_dim_shapes, major_dim_strides = [], []
212-
minor_dim_shapes, minor_dim_strides = [], []
213-
for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides):
214-
major_dim_shape_rev, major_dim_stride_rev = [], []
215-
minor_dim_shape_rev, minor_dim_stride_rev = [], []
216-
for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True):
217-
if d < t: # We will need to tile more dims
218-
fail_if(t % d != 0)
219-
t //= d
220-
minor_dim_shape_rev.append(d)
221-
minor_dim_stride_rev.append(s)
222-
elif t != 1: # Last dim to tile!
223-
fail_if(d % t != 0)
224-
minor_dim_shape_rev.append(t)
225-
minor_dim_stride_rev.append(s)
226-
if d != t: # No need to insert singleton dims.
227-
major_dim_shape_rev.append(d // t)
228-
major_dim_stride_rev.append(s * t)
229-
t = 1
230-
else: # Done tiling!
231-
major_dim_shape_rev.append(d)
232-
major_dim_stride_rev.append(s)
233-
fail_if(t != 1)
234-
major_dim_shapes.append(major_dim_shape_rev[::-1])
235-
minor_dim_shapes.append(minor_dim_shape_rev[::-1])
236-
major_dim_strides.append(major_dim_stride_rev[::-1])
237-
minor_dim_strides.append(minor_dim_stride_rev[::-1])
238-
shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) # type: ignore[arg-type]
239-
strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) # type: ignore[arg-type]
240-
return (
241-
tuple(tuple(d) if d else (1,) for d in shape),
242-
tuple(tuple(d) if d else (1,) for d in strides),
243-
)
244-
245-
def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
246-
for tile in self.tiles:
247-
untiled, tiled = indices[:-len(tile)], indices[-len(tile):]
248-
indices = (
249-
*untiled,
250-
*(i // t for i, t in zip(tiled, tile)),
251-
*(i % t for i, t in zip(tiled, tile)),
252-
)
253-
return indices
254-
255-
def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
256-
for tile in reversed(self.tiles):
257-
untiled = indices[:-2 * len(tile)]
258-
outer = indices[-2 * len(tile):-len(tile)]
259-
inner = indices[-len(tile):]
260-
indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile)))
261-
return indices
26251

26352
def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]:
26453
"""Like built-in enumerate, but returns negative indices into the sequence."""

jaxlib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,7 @@ pytype_strict_library(
9090
"//jaxlib/mlir/_mlir_libs:_jax_mlir_ext",
9191
"//jaxlib/mosaic",
9292
"//jaxlib/mosaic/python:gpu_dialect",
93+
"//jaxlib/mosaic/python:mgpu_ext",
9394
"//jaxlib/mosaic/python:tpu_dialect",
9495
"//jaxlib/triton",
9596
"@xla//xla/python:_profile_data",

jaxlib/mosaic/gpu/BUILD

Lines changed: 37 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,10 @@ package(
2626
py_library(
2727
name = "mosaic_gpu",
2828
data = [":libmosaic_gpu_runtime.so"],
29-
deps = [":_mosaic_gpu_ext"],
29+
deps = [
30+
":_mgpu_ext",
31+
":_mosaic_gpu_ext",
32+
],
3033
)
3134

3235
cc_library(
@@ -379,12 +382,29 @@ nanobind_extension(
379382
"//jaxlib:kernel_nanobind_helpers",
380383
"//jaxlib/cuda:cuda_vendor",
381384
"@com_google_absl//absl/cleanup",
385+
"@com_google_absl//absl/hash",
382386
"@com_google_absl//absl/strings",
383387
"@nanobind",
384388
"@xla//xla/tsl/cuda:cudart",
385389
],
386390
)
387391

392+
nanobind_extension(
393+
name = "_mgpu_ext",
394+
srcs = ["mgpu_ext.cc"],
395+
copts = [
396+
"-fexceptions",
397+
"-fno-strict-aliasing",
398+
],
399+
deps = [
400+
"//jaxlib/mosaic/gpu:tiled_layout",
401+
"@com_google_absl//absl/cleanup",
402+
"@com_google_absl//absl/hash",
403+
"@com_google_absl//absl/strings",
404+
"@nanobind",
405+
],
406+
)
407+
388408
cc_binary(
389409
name = "libmosaic_gpu_runtime.so",
390410
srcs = ["runtime.cc"],
@@ -405,3 +425,19 @@ cc_library(
405425
name = "library_paths",
406426
hdrs = ["library_paths.h"],
407427
)
428+
429+
cc_library(
430+
name = "tiled_layout",
431+
srcs = ["tiled_layout.cc"],
432+
hdrs = ["tiled_layout.h"],
433+
deps = ["@com_google_absl//absl/log:check"],
434+
)
435+
436+
cc_test(
437+
name = "tiled_layout_test",
438+
srcs = ["tiled_layout_test.cc"],
439+
deps = [
440+
":tiled_layout",
441+
"//testing/base/public:gunit_main",
442+
],
443+
)

0 commit comments

Comments
 (0)