Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 1 addition & 212 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,219 +46,8 @@
c = utils.c


@dataclasses.dataclass(frozen=True)
class Tiling:
"""A tiling expression describing a permutation of elements of an nd-array.

To apply one level of tiling to an array, each of the trailing dimensions (up
to the rank of the tile) is unfolded into two dimensions: first equal to the
ratio of the dimension size and the tile size, and second equal to the tile
size. Then, all newly unfolded minor dimensions are transposed to appear at
the end.

This expression describes multi-level tiling, by applying each element of
`tiles` in sequence to the array.

See https://openxla.org/xla/tiled_layout for a more detailed explanation.
"""
tiles: tuple[tuple[int, ...], ...]

def __post_init__(self):
if not self.tiles:
return
last_tile_rank = len(self.tiles[0])
for tile in self.tiles:
if len(tile) > last_tile_rank:
raise ValueError("Tiles must have a decreasing rank")
if not tile:
raise ValueError("Tiles must not be empty")
if any(d <= 0 for d in tile):
raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}")
last_tile_rank = len(tile)

def __str__(self):
return f"Tiling({''.join(map(str, self.tiles))})"

def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
"""Computes the shape of an array after tiling."""
orig_shape = shape
def fail():
raise ValueError(f"Tiling {self.tiles} does not apply to shape {orig_shape}")
for tile in self.tiles:
if len(tile) > len(shape):
fail()
untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):]
if any(s % t != 0 for s, t in zip(tiled_dims, tile)):
fail()
shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile)
return shape

def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
"""Computes the shape of an array before tiling from its tiled shape."""
orig_shape = shape
def fail():
raise ValueError(
f"shape {orig_shape} is not a valid result of applying tiling {self}."
)
for tile in reversed(self.tiles):
if len(tile) > len(shape):
fail()
untiled_dims = shape[:-2 * len(tile)]
tiled_dims = shape[-2 * len(tile):-len(tile)]
tiling_dims = shape[-len(tile):]
if tiling_dims != tile:
fail()
shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile)))
return shape

def canonicalize(self) -> Tiling:
"""Returns a canonicalized version of the tiling.

We define a tiling to be canonical if, at each step (except the first one,
which defines the base tile shape):

1. The tiling partitions at least one dimension in more than 1 tile. For
example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it
yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which
allows getting rid of the unnecessary `1` dimensions.
2. The leading dimensions of each tile are not `1`. If canonicalizing a
tile in this way leads to an empty tile, then the tile is given shape
`(1,)`---which is still a meaningful (final) tile. For example, the
tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape
`(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows
getting rid of the unnecessary `1` dimension, and yields a shape
`(8, 2, 4)`.
"""
if len(self.tiles) <= 1:
return self
Tiling = mgpu.dialect._mgpu_ext.Tiling

shape = self.tiles[0]
new_tiling = [self.tiles[0]]
for tile in self.tiles[1:]:
for i, d in enumerate(tile):
if d != 1:
canonical_tile = tile[i:]
break
else:
canonical_tile = (1,)
tiled_dims = shape[-len(canonical_tile):]
if tiled_dims == canonical_tile:
continue
shape = canonical_tile
new_tiling.append(canonical_tile)
return Tiling(tuple(new_tiling))

def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]:
"""Computes the strides of an array after tiling."""
for tile in self.tiles:
untiled, tiled = strides[:-len(tile)], strides[-len(tile):]
strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled)
return strides

def tile_dimension(self, dim: int) -> tuple[bool, ...]:
"""Result is True whenever the tiled dim originated from the given input dim."""
tiling_rank = len(self.tiles[0])
if dim < 0 or dim >= tiling_rank:
raise ValueError(f"Invalid dimension {dim} for tiling {self}")
strides = [1] * tiling_rank
strides[dim] = 0
return tuple(s == 0 for s in self.tile_strides(tuple(strides)))

def remove_dimension(self, dim: int) -> Tiling:
"""Returns a tiling with the given dimension removed."""
tiling_rank = len(self.tiles[0])
if dim < 0 or dim >= tiling_rank:
raise ValueError(f"Invalid dimension {dim} for tiling {self}")
dim_in_tile = dim
tiles = []
last_tile_rank = len(self.tiles[0])
for t in self.tiles:
assert last_tile_rank >= len(t)
dim_in_tile -= last_tile_rank - len(t)
last_tile_rank = len(t)
if dim_in_tile >= 0:
t = t[:dim_in_tile] + t[dim_in_tile + 1:]
if not t: # If this tile is empty, all other tiles will be empty too.
break
tiles.append(t)
return Tiling(tuple(tiles))

def tile_nested_shape_strides(
self,
shape: tuple[tuple[int, ...], ...],
strides: tuple[tuple[int, ...], ...],
) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]:
"""A fused version of `tile_shape` and `tile_strides` for nested shapes.

By nested shape we mean that each logical dimension (i.e. each element of
shape/strides) is actually composed out of multiple physical dimensions.
For example, a row-major array of logical shape (128, 128) that is tiled
into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each
dim is split into two sub-dims) and nested strides of
((2 * 64 * 64, 64), (64 * 64, 1)).
"""
if len(shape) != len(strides):
raise ValueError(
f"Shape {shape} and strides {strides} must have the same length"
)
def fail_if(cond, shape=shape): # Capture shape now.
if cond:
raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}")
for tile in self.tiles:
fail_if(len(tile) > len(shape))
untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):]
untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):]
major_dim_shapes, major_dim_strides = [], []
minor_dim_shapes, minor_dim_strides = [], []
for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides):
major_dim_shape_rev, major_dim_stride_rev = [], []
minor_dim_shape_rev, minor_dim_stride_rev = [], []
for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True):
if d < t: # We will need to tile more dims
fail_if(t % d != 0)
t //= d
minor_dim_shape_rev.append(d)
minor_dim_stride_rev.append(s)
elif t != 1: # Last dim to tile!
fail_if(d % t != 0)
minor_dim_shape_rev.append(t)
minor_dim_stride_rev.append(s)
if d != t: # No need to insert singleton dims.
major_dim_shape_rev.append(d // t)
major_dim_stride_rev.append(s * t)
t = 1
else: # Done tiling!
major_dim_shape_rev.append(d)
major_dim_stride_rev.append(s)
fail_if(t != 1)
major_dim_shapes.append(major_dim_shape_rev[::-1])
minor_dim_shapes.append(minor_dim_shape_rev[::-1])
major_dim_strides.append(major_dim_stride_rev[::-1])
minor_dim_strides.append(minor_dim_stride_rev[::-1])
shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) # type: ignore[arg-type]
strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) # type: ignore[arg-type]
return (
tuple(tuple(d) if d else (1,) for d in shape),
tuple(tuple(d) if d else (1,) for d in strides),
)

def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
for tile in self.tiles:
untiled, tiled = indices[:-len(tile)], indices[-len(tile):]
indices = (
*untiled,
*(i // t for i, t in zip(tiled, tile)),
*(i % t for i, t in zip(tiled, tile)),
)
return indices

def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
for tile in reversed(self.tiles):
untiled = indices[:-2 * len(tile)]
outer = indices[-2 * len(tile):-len(tile)]
inner = indices[-len(tile):]
indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile)))
return indices

def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]:
"""Like built-in enumerate, but returns negative indices into the sequence."""
Expand Down
38 changes: 37 additions & 1 deletion jaxlib/mosaic/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ package(
# Python-facing Mosaic GPU target: bundles the runtime shared library as data
# and depends on both nanobind extension modules.
py_library(
    name = "mosaic_gpu",
    data = [":libmosaic_gpu_runtime.so"],
    # A rule may set `deps` only once; the former single-entry list is folded
    # into this one.
    deps = [
        ":_mgpu_ext",
        ":_mosaic_gpu_ext",
    ],
)

cc_library(
Expand Down Expand Up @@ -379,12 +382,29 @@ nanobind_extension(
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cuda:cuda_vendor",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
"@nanobind",
"@xla//xla/tsl/cuda:cudart",
],
)

# nanobind extension module `_mgpu_ext`, built from mgpu_ext.cc, which exposes
# the C++ tiling implementation from :tiled_layout to Python.
# NOTE(review): mgpu_ext.cc only includes absl/hash from Abseil; confirm
# whether absl/cleanup and absl/strings are still required here (e.g. for
# linking :tiled_layout) before trimming them.
nanobind_extension(
name = "_mgpu_ext",
srcs = ["mgpu_ext.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
deps = [
"//jaxlib/mosaic/gpu:tiled_layout",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
"@nanobind",
],
)

cc_binary(
name = "libmosaic_gpu_runtime.so",
srcs = ["runtime.cc"],
Expand All @@ -405,3 +425,19 @@ cc_library(
name = "library_paths",
hdrs = ["library_paths.h"],
)

# C++ library built from tiled_layout.cc / tiled_layout.h. It is a dependency
# of the `_mgpu_ext` nanobind extension and of `tiled_layout_test`.
cc_library(
name = "tiled_layout",
srcs = ["tiled_layout.cc"],
hdrs = ["tiled_layout.h"],
deps = ["@com_google_absl//absl/log:check"],
)

# C++ unit test for :tiled_layout, built from tiled_layout_test.cc.
# NOTE(review): `//testing/base/public:gunit_main` does not look like a label
# provided by this workspace or its external repositories; confirm that it
# resolves (or is rewritten) in the open-source build.
cc_test(
name = "tiled_layout_test",
srcs = ["tiled_layout_test.cc"],
deps = [
":tiled_layout",
"//testing/base/public:gunit_main",
],
)
112 changes: 112 additions & 0 deletions jaxlib/mosaic/gpu/mgpu_ext.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/* Copyright 2025 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <cstdint>
#include <new>
#include <stdexcept>
#include <tuple>
#include <vector>

#include "absl/hash/hash.h"
#include "nanobind/nanobind.h"
#include "nanobind/operators.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/tuple.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "jaxlib/mosaic/gpu/tiled_layout.h"

namespace mgpu = jax::mosaic::gpu;
namespace nb = nanobind;

namespace mosaic::gpu {

NB_MODULE(_mgpu_ext, m) {
nb::class_<mgpu::Tiling>(m, "Tiling")
.def(nb::init<std::vector<std::vector<int64_t>>>(), nb::arg("tiles"))
.def(
"tile_shape",
[](const mgpu::Tiling& self, const std::vector<int64_t>& shape) {
return nb::tuple(nb::cast(self.TileShape(shape)));
},
nb::arg("shape"))
.def(
"untile_shape",
[](const mgpu::Tiling& self, const std::vector<int64_t>& shape) {
return nb::tuple(nb::cast(self.UntileShape(shape)));
},
nb::arg("shape"))
.def(
"tile_strides",
[](const mgpu::Tiling& self, const std::vector<int64_t>& strides) {
return nb::tuple(nb::cast(self.TileStrides(strides)));
},
nb::arg("strides"))
.def(
"tile_indices",
[](const mgpu::Tiling& self, const std::vector<int64_t>& indices) {
return nb::tuple(nb::cast(self.TileIndices(indices)));
},
nb::arg("indices"))
.def(
"untile_indices",
[](const mgpu::Tiling& self, const std::vector<int64_t>& indices) {
return nb::tuple(nb::cast(self.UntileIndices(indices)));
},
nb::arg("indices"))
.def(
"tile_nested_shape_strides",
[](const mgpu::Tiling& self,
const std::vector<std::vector<int64_t>>& shape,
const std::vector<std::vector<int64_t>>& strides) {
auto [tiled_shape, tiled_strides] =
self.TileNestedShapeStrides(shape, strides);
nb::list shape_list;
for (const auto& s : tiled_shape) {
shape_list.append(nb::tuple(nb::cast(s)));
}
nb::list strides_list;
for (const auto& s : tiled_strides) {
strides_list.append(nb::tuple(nb::cast(s)));
}
return nb::make_tuple(nb::tuple(shape_list),
nb::tuple(strides_list));
},
nb::arg("shape"), nb::arg("strides"))
.def(
"tile_dimension",
[](const mgpu::Tiling& self, int64_t dim) {
return nb::tuple(nb::cast(self.TileDimension(dim)));
},
nb::arg("dim"))
.def("remove_dimension", &mgpu::Tiling::RemoveDimension, nb::arg("dim"))
.def("canonicalize", &mgpu::Tiling::Canonicalize)
.def_prop_ro("tiles",
[](const mgpu::Tiling& self) {
nb::list tiles_list;
for (const mgpu::Tiling::Tile& tile : self.tiles()) {
tiles_list.append(nb::tuple(nb::cast(tile)));
}
return nb::tuple(tiles_list);
})
.def("__str__", &mgpu::Tiling::ToString)
.def("__repr__", &mgpu::Tiling::ToString)
.def(nb::self == nb::self)
.def("__hash__", [](const mgpu::Tiling& self) {
return absl::Hash<mgpu::Tiling>{}(self);
});
}

} // namespace mosaic::gpu
Loading
Loading