Skip to content

Commit fcb3cc7

Browse files
[XLA:MGPU] Port Tiling to C++.
PiperOrigin-RevId: 834710515
1 parent dd12a48 commit fcb3cc7

File tree

9 files changed

+744
-213
lines changed

9 files changed

+744
-213
lines changed

jax/experimental/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ py_library_providing_imports_info(
265265
"//jax/_src:stages",
266266
"//jax/_src:util",
267267
"//jax/_src/lib",
268+
"//jax/experimental/mosaic/gpu/cc:ext",
268269
"//jaxlib/mlir:arithmetic_dialect",
269270
"//jaxlib/mlir:builtin_dialect",
270271
"//jaxlib/mlir:control_flow_dialect",
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
load("@rules_cc//cc:cc_library.bzl", "cc_library")
17+
load("@rules_cc//cc:cc_test.bzl", "cc_test")
18+
load("//jaxlib:jax.bzl", "nanobind_extension")
19+
20+
package(
21+
default_applicable_licenses = [],
22+
default_visibility = ["//visibility:private"],
23+
)
24+
25+
cc_library(
26+
name = "tiled_layout",
27+
srcs = ["tiled_layout.cc"],
28+
hdrs = ["tiled_layout.h"],
29+
deps = ["@com_google_absl//absl/log:check"],
30+
)
31+
32+
cc_test(
33+
name = "tiled_layout_test",
34+
srcs = ["tiled_layout_test.cc"],
35+
deps = [
36+
":tiled_layout",
37+
"//testing/base/public:gunit_main",
38+
],
39+
)
40+
41+
nanobind_extension(
42+
name = "ext",
43+
srcs = ["ext.cc"],
44+
visibility = ["//jax:__subpackages__"],
45+
deps = [
46+
":tiled_layout",
47+
"@com_google_absl//absl/hash",
48+
"@nanobind",
49+
],
50+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/* Copyright 2025 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <cstdint>
17+
#include <vector>
18+
19+
#include "absl/hash/hash.h"
20+
#include "nanobind/nanobind.h"
21+
#include "nanobind/operators.h"
22+
#include "nanobind/stl/pair.h"
23+
#include "nanobind/stl/string.h"
24+
#include "nanobind/stl/vector.h"
25+
#include "third_party/py/jax/experimental/mosaic/gpu/cc/tiled_layout.h"
26+
27+
namespace nb = nanobind;
28+
namespace mgpu = jax::mosaic::gpu;
29+
30+
NB_MODULE(ext, m) {
31+
nb::class_<mgpu::Tiling>(m, "Tiling")
32+
.def(nb::init<std::vector<std::vector<int64_t>>>(), nb::arg("tiles"))
33+
.def(
34+
"tile_shape",
35+
[](const mgpu::Tiling& self, const std::vector<int64_t>& shape) {
36+
return nb::tuple(nb::cast(self.TileShape(shape)));
37+
},
38+
nb::arg("shape"))
39+
.def(
40+
"untile_shape",
41+
[](const mgpu::Tiling& self, const std::vector<int64_t>& shape) {
42+
return nb::tuple(nb::cast(self.UntileShape(shape)));
43+
},
44+
nb::arg("shape"))
45+
.def(
46+
"tile_strides",
47+
[](const mgpu::Tiling& self, const std::vector<int64_t>& strides) {
48+
return nb::tuple(nb::cast(self.TileStrides(strides)));
49+
},
50+
nb::arg("strides"))
51+
.def(
52+
"tile_indices",
53+
[](const mgpu::Tiling& self, const std::vector<int64_t>& indices) {
54+
return nb::tuple(nb::cast(self.TileIndices(indices)));
55+
},
56+
nb::arg("indices"))
57+
.def(
58+
"untile_indices",
59+
[](const mgpu::Tiling& self, const std::vector<int64_t>& indices) {
60+
return nb::tuple(nb::cast(self.UntileIndices(indices)));
61+
},
62+
nb::arg("indices"))
63+
.def(
64+
"tile_nested_shape_strides",
65+
[](const mgpu::Tiling& self,
66+
const std::vector<std::vector<int64_t>>& shape,
67+
const std::vector<std::vector<int64_t>>& strides) {
68+
auto [tiled_shape, tiled_strides] =
69+
self.TileNestedShapeStrides(shape, strides);
70+
nb::list shape_list;
71+
for (const auto& s : tiled_shape) {
72+
shape_list.append(nb::tuple(nb::cast(s)));
73+
}
74+
nb::list strides_list;
75+
for (const auto& s : tiled_strides) {
76+
strides_list.append(nb::tuple(nb::cast(s)));
77+
}
78+
return nb::make_tuple(nb::tuple(shape_list),
79+
nb::tuple(strides_list));
80+
},
81+
nb::arg("shape"), nb::arg("strides"))
82+
.def(
83+
"tile_dimension",
84+
[](const mgpu::Tiling& self, int64_t dim) {
85+
return nb::tuple(nb::cast(self.TileDimension(dim)));
86+
},
87+
nb::arg("dim"))
88+
.def("remove_dimension", &mgpu::Tiling::RemoveDimension, nb::arg("dim"))
89+
.def("canonicalize", &mgpu::Tiling::Canonicalize)
90+
.def_prop_ro("tiles",
91+
[](const mgpu::Tiling& self) {
92+
nb::list tiles_list;
93+
for (const mgpu::Tiling::Tile& tile : self.tiles()) {
94+
tiles_list.append(nb::tuple(nb::cast(tile)));
95+
}
96+
return nb::tuple(tiles_list);
97+
})
98+
.def("__str__", &mgpu::Tiling::ToString)
99+
.def("__repr__", &mgpu::Tiling::ToString)
100+
.def(nb::self == nb::self)
101+
.def("__hash__", [](const mgpu::Tiling& self) {
102+
return absl::Hash<mgpu::Tiling>{}(self);
103+
});
104+
}

0 commit comments

Comments
 (0)