diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 7a4d21dbc45f..613707e3b46c 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -36,6 +36,9 @@ from . import utils +import sys +print(dir(mgpu.dialect), file=sys.stderr) + T = TypeVar("T") WARPGROUP_SIZE = utils.WARPGROUP_SIZE @@ -46,219 +49,8 @@ c = utils.c -@dataclasses.dataclass(frozen=True) -class Tiling: - """A tiling expression describing a permutation of elements of an nd-array. - - To apply one level of tiling to an array, each of the trailing dimensions (up - to the rank of the tile) is unfolded into two dimensions: first equal to the - ratio of the dimension size and the tile size, and second equal to the tile - size. Then, all newly unfolded minor dimensions are transposed to appear at - the end. - - This expression describes multi-level tiling, by applying each element of - `tiles` in sequence to the array. - - See https://openxla.org/xla/tiled_layout for a more detailed explanation. - """ - tiles: tuple[tuple[int, ...], ...] +Tiling = mgpu.dialect.Tiling - def __post_init__(self): - if not self.tiles: - return - last_tile_rank = len(self.tiles[0]) - for tile in self.tiles: - if len(tile) > last_tile_rank: - raise ValueError("Tiles must have a decreasing rank") - if not tile: - raise ValueError("Tiles must not be empty") - if any(d <= 0 for d in tile): - raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}") - last_tile_rank = len(tile) - - def __str__(self): - return f"Tiling({''.join(map(str, self.tiles))})" - - def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: - """Computes the shape of an array after tiling.""" - orig_shape = shape - def fail(): - raise ValueError(f"Tiling {self.tiles} does not apply to shape {orig_shape}") - for tile in self.tiles: - if len(tile) > len(shape): - fail() - untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):] - if any(s % t != 0 for s, t in zip(tiled_dims, tile)): - fail() - shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile) - return shape - - def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: - """Computes the shape of an array before tiling from its tiled shape.""" - orig_shape = shape - def fail(): - raise ValueError( - f"shape {orig_shape} is not a valid result of applying tiling {self}." - ) - for tile in reversed(self.tiles): - if len(tile) > len(shape): - fail() - untiled_dims = shape[:-2 * len(tile)] - tiled_dims = shape[-2 * len(tile):-len(tile)] - tiling_dims = shape[-len(tile):] - if tiling_dims != tile: - fail() - shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile))) - return shape - - def canonicalize(self) -> Tiling: - """Returns a canonicalized version of the tiling. - - We define a tiling to be canonical if, at each step (except the first one, - which defines the base tile shape): - - 1. The tiling partitions at least one dimension in more than 1 tile. For - example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it - yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which - allows getting rid of the unnecessary `1` dimensions. - 2. The leading dimensions of each tile are not `1`. If canonicalizing a - tile in this way leads to an empty tile, then the tile is given shape - `(1,)`---which is still a meaningful (final) tile. For example, the - tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape - `(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows - getting rid of the unnecessary `1` dimension, and yields a shape - `(8, 2, 4)`. - """ - if len(self.tiles) <= 1: - return self - - shape = self.tiles[0] - new_tiling = [self.tiles[0]] - for tile in self.tiles[1:]: - for i, d in enumerate(tile): - if d != 1: - canonical_tile = tile[i:] - break - else: - canonical_tile = (1,) - tiled_dims = shape[-len(canonical_tile):] - if tiled_dims == canonical_tile: - continue - shape = canonical_tile - new_tiling.append(canonical_tile) - return Tiling(tuple(new_tiling)) - - def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: - """Computes the strides of an array after tiling.""" - for tile in self.tiles: - untiled, tiled = strides[:-len(tile)], strides[-len(tile):] - strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) - return strides - - def tile_dimension(self, dim: int) -> tuple[bool, ...]: - """Result is True whenever the tiled dim originated from the given input dim.""" - tiling_rank = len(self.tiles[0]) - if dim < 0 or dim >= tiling_rank: - raise ValueError(f"Invalid dimension {dim} for tiling {self}") - strides = [1] * tiling_rank - strides[dim] = 0 - return tuple(s == 0 for s in self.tile_strides(tuple(strides))) - - def remove_dimension(self, dim: int) -> Tiling: - """Returns a tiling with the given dimension removed.""" - tiling_rank = len(self.tiles[0]) - if dim < 0 or dim >= tiling_rank: - raise ValueError(f"Invalid dimension {dim} for tiling {self}") - dim_in_tile = dim - tiles = [] - last_tile_rank = len(self.tiles[0]) - for t in self.tiles: - assert last_tile_rank >= len(t) - dim_in_tile -= last_tile_rank - len(t) - last_tile_rank = len(t) - if dim_in_tile >= 0: - t = t[:dim_in_tile] + t[dim_in_tile + 1:] - if not t: # If this tile is empty, all other tiles will be empty too. - break - tiles.append(t) - return Tiling(tuple(tiles)) - - def tile_nested_shape_strides( - self, - shape: tuple[tuple[int, ...], ...], - strides: tuple[tuple[int, ...], ...], - ) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]: - """A fused version of `tile_shape` and `tile_strides` for nested shapes. - - By nested shape we mean that each logical dimension (i.e. each element of - shape/strides) is actually composed out of multiple physical dimensions. - For example, a row-major array of logical shape (128, 128) that is tiled - into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each - dim is split into two sub-dims) and nested strides of - ((2 * 64 * 64, 64), (64 * 64, 1)). - """ - if len(shape) != len(strides): - raise ValueError( - f"Shape {shape} and strides {strides} must have the same length" - ) - def fail_if(cond, shape=shape): # Capture shape now. - if cond: - raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") - for tile in self.tiles: - fail_if(len(tile) > len(shape)) - untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):] - untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):] - major_dim_shapes, major_dim_strides = [], [] - minor_dim_shapes, minor_dim_strides = [], [] - for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides): - major_dim_shape_rev, major_dim_stride_rev = [], [] - minor_dim_shape_rev, minor_dim_stride_rev = [], [] - for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True): - if d < t: # We will need to tile more dims - fail_if(t % d != 0) - t //= d - minor_dim_shape_rev.append(d) - minor_dim_stride_rev.append(s) - elif t != 1: # Last dim to tile! - fail_if(d % t != 0) - minor_dim_shape_rev.append(t) - minor_dim_stride_rev.append(s) - if d != t: # No need to insert singleton dims. - major_dim_shape_rev.append(d // t) - major_dim_stride_rev.append(s * t) - t = 1 - else: # Done tiling! - major_dim_shape_rev.append(d) - major_dim_stride_rev.append(s) - fail_if(t != 1) - major_dim_shapes.append(major_dim_shape_rev[::-1]) - minor_dim_shapes.append(minor_dim_shape_rev[::-1]) - major_dim_strides.append(major_dim_stride_rev[::-1]) - minor_dim_strides.append(minor_dim_stride_rev[::-1]) - shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) # type: ignore[arg-type] - strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) # type: ignore[arg-type] - return ( - tuple(tuple(d) if d else (1,) for d in shape), - tuple(tuple(d) if d else (1,) for d in strides), - ) - - def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: - for tile in self.tiles: - untiled, tiled = indices[:-len(tile)], indices[-len(tile):] - indices = ( - *untiled, - *(i // t for i, t in zip(tiled, tile)), - *(i % t for i, t in zip(tiled, tile)), - ) - return indices - - def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: - for tile in reversed(self.tiles): - untiled = indices[:-2 * len(tile)] - outer = indices[-2 * len(tile):-len(tile)] - inner = indices[-len(tile):] - indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile))) - return indices def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: """Like built-in enumerate, but returns negative indices into the sequence.""" diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 7b3411c65439..c32c27db25a6 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -133,6 +133,7 @@ nanobind_pywrap_extension( copts = COPTS, deps = [ "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + "//jaxlib/mosaic/gpu:tiled_layout", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index 42ee79ed43ca..787d52cb0aa4 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -16,13 +16,20 @@ limitations under the License. #include #include +#include "absl/hash/hash.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep #include "nanobind/nanobind.h" +#include "nanobind/operators.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" +#include "jaxlib/mosaic/gpu/tiled_layout.h" namespace nb = nanobind; +namespace mgpu = jax::mosaic::gpu; NB_MODULE(_mosaic_gpu_ext, m) { m.def( @@ -143,4 +150,78 @@ NB_MODULE(_mosaic_gpu_ext, m) { .def_property_readonly("swizzle", [](MlirAttribute self) { return mlirMosaicGpuSwizzleTransformAttrGetSwizzle(self); }); + + nb::class_(m, "Tiling") + .def(nb::init>>(), nb::arg("tiles")) + .def( + "tile_shape", + [](const mgpu::Tiling& self, const std::vector& shape) { + return nb::tuple(nb::cast(self.TileShape(shape))); + }, + nb::arg("shape")) + .def( + "untile_shape", + [](const mgpu::Tiling& self, const std::vector& shape) { + return nb::tuple(nb::cast(self.UntileShape(shape))); + }, + nb::arg("shape")) + .def( + "tile_strides", + [](const mgpu::Tiling& self, const std::vector& strides) { + return nb::tuple(nb::cast(self.TileStrides(strides))); + }, + nb::arg("strides")) + .def( + "tile_indices", + [](const mgpu::Tiling& self, const std::vector& indices) { + return nb::tuple(nb::cast(self.TileIndices(indices))); + }, + nb::arg("indices")) + .def( + "untile_indices", + [](const mgpu::Tiling& self, const std::vector& indices) { + return nb::tuple(nb::cast(self.UntileIndices(indices))); + }, + nb::arg("indices")) + .def( + "tile_nested_shape_strides", + [](const mgpu::Tiling& self, + const std::vector>& shape, + const std::vector>& strides) { + auto [tiled_shape, tiled_strides] = + self.TileNestedShapeStrides(shape, strides); + nb::list shape_list; + for (const auto& s : tiled_shape) { + shape_list.append(nb::tuple(nb::cast(s))); + } + nb::list strides_list; + for (const auto& s : tiled_strides) { + strides_list.append(nb::tuple(nb::cast(s))); + } + return nb::make_tuple(nb::tuple(shape_list), + nb::tuple(strides_list)); + }, + nb::arg("shape"), nb::arg("strides")) + .def( + "tile_dimension", + [](const mgpu::Tiling& self, int64_t dim) { + return nb::tuple(nb::cast(self.TileDimension(dim))); + }, + nb::arg("dim")) + .def("remove_dimension", &mgpu::Tiling::RemoveDimension, nb::arg("dim")) + .def("canonicalize", &mgpu::Tiling::Canonicalize) + .def_prop_ro("tiles", + [](const mgpu::Tiling& self) { + nb::list tiles_list; + for (const mgpu::Tiling::Tile& tile : self.tiles()) { + tiles_list.append(nb::tuple(nb::cast(tile))); + } + return nb::tuple(tiles_list); + }) + .def("__str__", &mgpu::Tiling::ToString) + .def("__repr__", &mgpu::Tiling::ToString) + .def(nb::self == nb::self) + .def("__hash__", [](const mgpu::Tiling& self) { + return absl::Hash{}(self); + }); } diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index b7359495f184..3b93ff271f4c 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -405,3 +405,19 @@ cc_library( name = "library_paths", hdrs = ["library_paths.h"], ) + +cc_library( + name = "tiled_layout", + srcs = ["tiled_layout.cc"], + hdrs = ["tiled_layout.h"], + deps = ["@com_google_absl//absl/log:check"], +) + +cc_test( + name = "tiled_layout_test", + srcs = ["tiled_layout_test.cc"], + deps = [ + ":tiled_layout", + "//testing/base/public:gunit_main", + ], +) diff --git a/jaxlib/mosaic/gpu/tiled_layout.cc b/jaxlib/mosaic/gpu/tiled_layout.cc new file mode 100644 index 000000000000..1bd24ab5f747 --- /dev/null +++ b/jaxlib/mosaic/gpu/tiled_layout.cc @@ -0,0 +1,350 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/tiled_layout.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" + +namespace jax::mosaic::gpu { + +Tiling::Tiling(std::vector tiles) : tiles_(std::move(tiles)) { + if (tiles_.empty()) return; + size_t last_tile_rank = std::numeric_limits::max(); + for (const Tile& tile : tiles_) { + CHECK(tile.size() <= last_tile_rank) << "Tiles must have a decreasing rank"; + CHECK(!tile.empty()) << "Tiles must not be empty"; + for (int64_t d : tile) { + CHECK(d > 0) << "Tile shape must only have positive sizes"; + } + last_tile_rank = tile.size(); + } +} + +std::vector Tiling::TileShape( + const std::vector& shape) const { + std::vector current_shape = shape; + for (const Tile& tile : tiles_) { + CHECK(tile.size() <= current_shape.size()) + << "Tiling does not apply to shape"; + size_t untiled_rank = current_shape.size() - tile.size(); + std::vector next_shape; + next_shape.reserve(untiled_rank + 2 * tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_shape.push_back(current_shape[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + int64_t dim = current_shape[untiled_rank + i]; + int64_t t = tile[i]; + CHECK(dim % t == 0) << "Dimension not divisible by tile size"; + next_shape.push_back(dim / t); + } + for (int64_t t : tile) { + next_shape.push_back(t); + } + current_shape = std::move(next_shape); + } + return current_shape; +} + +std::vector Tiling::UntileShape( + const std::vector& shape) const { + std::vector current_shape = shape; + for (auto it = tiles_.rbegin(); it != tiles_.rend(); ++it) { + const Tile& tile = *it; + CHECK(tile.size() * 2 <= current_shape.size()) << "Invalid tiled shape"; + size_t untiled_rank = current_shape.size() - 2 * tile.size(); + std::vector next_shape; + next_shape.reserve(untiled_rank + tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_shape.push_back(current_shape[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + int64_t outer = current_shape[untiled_rank + i]; + int64_t inner = current_shape[untiled_rank + tile.size() + i]; + CHECK(inner == tile[i]) << "Tiling dimension mismatch"; + next_shape.push_back(outer * inner); + } + current_shape = std::move(next_shape); + } + return current_shape; +} + +std::vector Tiling::TileStrides( + const std::vector& strides) const { + std::vector current_strides = strides; + for (const Tile& tile : tiles_) { + size_t untiled_rank = current_strides.size() - tile.size(); + std::vector next_strides; + next_strides.reserve(untiled_rank + 2 * tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_strides.push_back(current_strides[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + next_strides.push_back(current_strides[untiled_rank + i] * tile[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + next_strides.push_back(current_strides[untiled_rank + i]); + } + current_strides = std::move(next_strides); + } + return current_strides; +} + +std::vector Tiling::TileIndices( + const std::vector& indices) const { + std::vector current_indices = indices; + for (const Tile& tile : tiles_) { + size_t untiled_rank = current_indices.size() - tile.size(); + std::vector next_indices; + next_indices.reserve(untiled_rank + 2 * tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_indices.push_back(current_indices[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + next_indices.push_back(current_indices[untiled_rank + i] / tile[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + next_indices.push_back(current_indices[untiled_rank + i] % tile[i]); + } + current_indices = std::move(next_indices); + } + return current_indices; +} + +std::vector Tiling::UntileIndices( + const std::vector& indices) const { + std::vector current_indices = indices; + for (auto it = tiles_.rbegin(); it != tiles_.rend(); ++it) { + const Tile& tile = *it; + size_t untiled_rank = current_indices.size() - 2 * tile.size(); + std::vector next_indices; + next_indices.reserve(untiled_rank + tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_indices.push_back(current_indices[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + int64_t outer = current_indices[untiled_rank + i]; + int64_t inner = current_indices[untiled_rank + tile.size() + i]; + next_indices.push_back(outer * tile[i] + inner); + } + current_indices = std::move(next_indices); + } + return current_indices; +} + +std::pair>, std::vector>> +Tiling::TileNestedShapeStrides( + const std::vector>& shape, + const std::vector>& strides) const { + CHECK(shape.size() == strides.size()) + << "Shape and strides must have the same length"; + std::vector> current_shape = shape; + std::vector> current_strides = strides; + + for (const Tile& tile : tiles_) { + CHECK(tile.size() <= current_shape.size()) + << "Tiling does not apply to shape"; + size_t untiled_rank = current_shape.size() - tile.size(); + std::vector> next_shape; + std::vector> next_strides; + next_shape.reserve(untiled_rank + 2 * tile.size()); + next_strides.reserve(untiled_rank + 2 * tile.size()); + + for (size_t i = 0; i < untiled_rank; ++i) { + next_shape.push_back(current_shape[i]); + next_strides.push_back(current_strides[i]); + } + + std::vector> major_dim_shapes; + std::vector> minor_dim_shapes; + std::vector> major_dim_strides; + std::vector> minor_dim_strides; + + for (size_t i = 0; i < tile.size(); ++i) { + int64_t t = tile[i]; + const std::vector& dim_shape = current_shape[untiled_rank + i]; + const std::vector& dim_strides = + current_strides[untiled_rank + i]; + + std::vector major_dim_shape_rev, major_dim_stride_rev; + std::vector minor_dim_shape_rev, minor_dim_stride_rev; + + for (size_t j = 0; j < dim_shape.size(); ++j) { + size_t idx = dim_shape.size() - 1 - j; + int64_t d = dim_shape[idx]; + int64_t s = dim_strides[idx]; + + if (d < t) { + CHECK(t % d == 0) << "Dimension not divisible by tile size"; + t /= d; + minor_dim_shape_rev.push_back(d); + minor_dim_stride_rev.push_back(s); + } else if (t != 1) { + CHECK(d % t == 0) << "Dimension not divisible by tile size"; + minor_dim_shape_rev.push_back(t); + minor_dim_stride_rev.push_back(s); + if (d != t) { + major_dim_shape_rev.push_back(d / t); + major_dim_stride_rev.push_back(s * t); + } + t = 1; + } else { + major_dim_shape_rev.push_back(d); + major_dim_stride_rev.push_back(s); + } + } + CHECK(t == 1) << "Tile size too large for dimension"; + + major_dim_shapes.push_back(std::vector( + major_dim_shape_rev.rbegin(), major_dim_shape_rev.rend())); + major_dim_strides.push_back(std::vector( + major_dim_stride_rev.rbegin(), major_dim_stride_rev.rend())); + minor_dim_shapes.push_back(std::vector( + minor_dim_shape_rev.rbegin(), minor_dim_shape_rev.rend())); + minor_dim_strides.push_back(std::vector( + minor_dim_stride_rev.rbegin(), minor_dim_stride_rev.rend())); + } + next_shape.insert(next_shape.end(), major_dim_shapes.begin(), + major_dim_shapes.end()); + next_shape.insert(next_shape.end(), minor_dim_shapes.begin(), + minor_dim_shapes.end()); + next_strides.insert(next_strides.end(), major_dim_strides.begin(), + major_dim_strides.end()); + next_strides.insert(next_strides.end(), minor_dim_strides.begin(), + minor_dim_strides.end()); + current_shape = std::move(next_shape); + current_strides = std::move(next_strides); + } + + auto normalize = [](std::vector>& v) { + for (std::vector& d : v) { + if (d.empty()) { + d.push_back(1); + } + } + }; + normalize(current_shape); + normalize(current_strides); + + return {std::move(current_shape), std::move(current_strides)}; +} + +std::vector Tiling::TileDimension(int dim) const { + size_t tiling_rank = tiles_[0].size(); + CHECK(dim >= 0 && dim < tiling_rank) << "Invalid dimension"; + std::vector strides(tiling_rank, 1); + strides[dim] = 0; + std::vector tiled_strides = TileStrides(strides); + std::vector result; + result.reserve(tiled_strides.size()); + for (int64_t s : tiled_strides) { + result.push_back(s == 0); + } + return result; +} + +Tiling Tiling::RemoveDimension(int dim) const { + size_t tiling_rank = tiles_[0].size(); + CHECK(dim >= 0 && dim < tiling_rank) << "Invalid dimension"; + int dim_in_tile = dim; + std::vector new_tiles; + size_t last_tile_rank = tiling_rank; + for (Tile t : tiles_) { + CHECK(last_tile_rank >= t.size()) << "Rank invariant violated"; + dim_in_tile -= (last_tile_rank - t.size()); + last_tile_rank = t.size(); + if (dim_in_tile >= 0) { + t.erase(t.begin() + dim_in_tile); + } + if (t.empty()) break; + new_tiles.push_back(std::move(t)); + } + return Tiling(std::move(new_tiles)); +} + +Tiling Tiling::Canonicalize() const { + if (tiles_.size() <= 1) return *this; + std::vector new_tiles; + new_tiles.push_back(tiles_[0]); + Tile shape = tiles_[0]; + for (size_t i = 1; i < tiles_.size(); ++i) { + const Tile& tile = tiles_[i]; + Tile canonical_tile; + bool found_non_one = false; + for (size_t j = 0; j < tile.size(); ++j) { + if (tile[j] != 1) { + canonical_tile.assign(tile.begin() + j, tile.end()); + found_non_one = true; + break; + } + } + if (!found_non_one) { + canonical_tile = {1}; + } + + bool redundant = true; + if (shape.size() < canonical_tile.size()) { + redundant = false; + } else { + for (size_t k = 0; k < canonical_tile.size(); ++k) { + if (shape[shape.size() - canonical_tile.size() + k] != + canonical_tile[k]) { + redundant = false; + break; + } + } + } + + if (redundant) continue; + shape = canonical_tile; + new_tiles.push_back(std::move(canonical_tile)); + } + return Tiling(std::move(new_tiles)); +} + +std::string Tiling::ToString() const { + std::stringstream ss; + ss << "Tiling("; + for (const Tile& tile : tiles_) { + ss << "("; + for (size_t i = 0; i < tile.size(); ++i) { + if (i > 0) ss << ", "; + ss << tile[i]; + } + if (tile.size() == 1) ss << ","; + ss << ")"; + } + ss << ")"; + return ss.str(); +} + +bool Tiling::operator==(const Tiling& other) const { + return tiles_ == other.tiles_; +} + +std::ostream& operator<<(std::ostream& os, const Tiling& tiling) { + return os << tiling.ToString(); +} + +} // namespace jax::mosaic::gpu diff --git a/jaxlib/mosaic/gpu/tiled_layout.h b/jaxlib/mosaic/gpu/tiled_layout.h new file mode 100644 index 000000000000..9b22a9583d16 --- /dev/null +++ b/jaxlib/mosaic/gpu/tiled_layout.h @@ -0,0 +1,111 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_PY_JAX_EXPERIMENTAL_MOSAIC_GPU_CC_TILED_LAYOUT_H_ +#define THIRD_PARTY_PY_JAX_EXPERIMENTAL_MOSAIC_GPU_CC_TILED_LAYOUT_H_ + +#include +#include +#include +#include + +namespace jax::mosaic::gpu { + +// A tiling expression describing a permutation of elements of an nd-array. +// +// To apply one level of tiling to an array, each of the trailing dimensions (up +// to the rank of the tile) is unfolded into two dimensions: first equal to the +// ratio of the dimension size and the tile size, and second equal to the tile +// size. Then, all newly unfolded minor dimensions are transposed to appear at +// the end. +// +// This expression describes multi-level tiling, by applying each element of +// `tiles` in sequence to the array. +// +// See https://openxla.org/xla/tiled_layout for a more detailed explanation. +class Tiling { + public: + using Tile = std::vector; + + explicit Tiling(std::vector tiles); + + bool operator==(const Tiling& other) const; + bool operator!=(const Tiling& other) const { return !(*this == other); } + + const std::vector& tiles() const { return tiles_; } + + // Compute the shape of an array after tiling. + std::vector TileShape(const std::vector& shape) const; + + // Compute the shape of an array before tiling from its tiled shape. + std::vector UntileShape(const std::vector& shape) const; + + // Compute the strides of an array after tiling. + std::vector TileStrides(const std::vector& strides) const; + + // Compute the indices of an array after tiling. + std::vector TileIndices(const std::vector& indices) const; + + // Compute the indices of an array before tiling from its tiled indices. + std::vector UntileIndices(const std::vector& indices) const; + + // A fused version of `TileShape` and `TileStrides` for nested shapes. + // + // By nested shape we mean that each logical dimension (i.e. each element of + // shape/strides) is actually composed out of multiple physical dimensions. + // For example, a row-major array of logical shape (128, 128) that is tiled + // into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each + // dim is split into two sub-dims) and nested strides of + // ((2 * 64 * 64, 64), (64 * 64, 1)). + std::pair, std::vector> TileNestedShapeStrides( + const std::vector>& shape, + const std::vector>& strides) const; + + // Returns true if the tiled dim originated from the given input dim. + std::vector TileDimension(int dim) const; + + // Returns a tiling with the given dimension removed. + Tiling RemoveDimension(int dim) const; + + // We define a tiling to be canonical if, at each step (except the first one, + // which defines the base tile shape): + + // 1. The tiling partitions at least one dimension in more than 1 tile. For + // example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it + // yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which + // allows getting rid of the unnecessary `1` dimensions. + // 2. The leading dimensions of each tile are not `1`. If canonicalizing a + // tile in this way leads to an empty tile, then the tile is given shape + // `(1,)`---which is still a meaningful (final) tile. For example, the + // tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape + // `(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows + // getting rid of the unnecessary `1` dimension, and yields a shape + // `(8, 2, 4)`. + Tiling Canonicalize() const; + + std::string ToString() const; + + template + friend H AbslHashValue(H h, const Tiling& tiling) { + return H::combine(std::move(h), tiling.tiles_); + } + + private: + std::vector tiles_; +}; + +} // namespace jax::mosaic::gpu + +#endif // THIRD_PARTY_PY_JAX_EXPERIMENTAL_MOSAIC_GPU_CC_TILED_LAYOUT_H_ diff --git a/jaxlib/mosaic/gpu/tiled_layout_test.cc b/jaxlib/mosaic/gpu/tiled_layout_test.cc new file mode 100644 index 000000000000..145f8091d819 --- /dev/null +++ b/jaxlib/mosaic/gpu/tiled_layout_test.cc @@ -0,0 +1,111 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/tiled_layout.h" + +#include +#include + +#include +#include + +namespace jax::mosaic::gpu { +namespace { + +using ::testing::ElementsAre; + +TEST(TilingTest, TileNestedShapeStrides) { + Tiling tiling({{64, 64}}); + std::vector> shape = {{128}, {128}}; + std::vector> strides = {{128}, {1}}; + + auto [tiled_shape, tiled_strides] = + tiling.TileNestedShapeStrides(shape, strides); + + std::vector> expected_shape = {{2}, {2}, {64}, {64}}; + std::vector> expected_strides = { + {64 * 128}, {64}, {128}, {1}}; + EXPECT_EQ(tiled_shape, expected_shape); + EXPECT_EQ(tiled_strides, expected_strides); +} + +TEST(TilingTest, TileNestedShapeStridesAlreadySplit) { + Tiling tiling({{64, 64}}); + std::vector> shape = {{2, 64}, {2, 64}}; + std::vector> strides = {{64 * 128, 128}, {64, 1}}; + + auto [tiled_shape, tiled_strides] = + tiling.TileNestedShapeStrides(shape, strides); + + std::vector> expected_shape = {{2}, {2}, {64}, {64}}; + std::vector> expected_strides = { + {64 * 128}, {64}, {128}, {1}}; + EXPECT_EQ(tiled_shape, expected_shape); + EXPECT_EQ(tiled_strides, expected_strides); +} + +TEST(TilingTest, TileNestedShapeStridesMultiLevel) { + Tiling tiling({{64, 64}, {8}}); + std::vector> shape = {{128}, {128}}; + std::vector> strides = {{128}, {1}}; + + auto [tiled_shape, tiled_strides] = + tiling.TileNestedShapeStrides(shape, strides); + + std::vector> expected_shape = {{2}, {2}, {64}, {8}, {8}}; + std::vector> expected_strides = { + {8192}, {64}, {128}, {8}, {1}}; + EXPECT_EQ(tiled_shape, expected_shape); + EXPECT_EQ(tiled_strides, expected_strides); +} + +TEST(TilingTest, TileIndices) { + Tiling tiling({{64, 64}}); + std::vector indices = {70, 80}; + + std::vector tiled_indices = tiling.TileIndices(indices); + + EXPECT_THAT(tiled_indices, ElementsAre(1, 1, 6, 16)); +} + +TEST(TilingTest, UntileIndices) { + Tiling tiling({{64, 64}}); + std::vector indices = {1, 1, 6, 16}; + + std::vector untiled_indices = tiling.UntileIndices(indices); + + EXPECT_THAT(untiled_indices, ElementsAre(70, 80)); +} + +TEST(TilingTest, TileIndices_MultiLevel) { + Tiling tiling({{64, 64}, {8}}); + std::vector indices = {70, 80}; + + std::vector tiled_indices = tiling.TileIndices(indices); + + EXPECT_THAT(tiled_indices, ElementsAre(1, 1, 6, 2, 0)); +} + +TEST(TilingTest, UntileIndices_MultiLevel) { + Tiling tiling({{64, 64}, {8}}); + std::vector indices = {1, 1, 6, 2, 0}; + + auto untiled_indices = tiling.UntileIndices(indices); + + EXPECT_THAT(untiled_indices, ElementsAre(70, 80)); +} + +} // namespace +} // namespace jax::mosaic::gpu