|
46 | 46 | c = utils.c |
47 | 47 |
|
48 | 48 |
|
49 | | -@dataclasses.dataclass(frozen=True) |
50 | | -class Tiling: |
51 | | - """A tiling expression describing a permutation of elements of an nd-array. |
52 | | -
|
53 | | - To apply one level of tiling to an array, each of the trailing dimensions (up |
54 | | - to the rank of the tile) is unfolded into two dimensions: first equal to the |
55 | | - ratio of the dimension size and the tile size, and second equal to the tile |
56 | | - size. Then, all newly unfolded minor dimensions are transposed to appear at |
57 | | - the end. |
58 | | -
|
59 | | - This expression describes multi-level tiling, by applying each element of |
60 | | - `tiles` in sequence to the array. |
61 | | -
|
62 | | - See https://openxla.org/xla/tiled_layout for a more detailed explanation. |
63 | | - """ |
64 | | - tiles: tuple[tuple[int, ...], ...] |
65 | | - |
66 | | - def __post_init__(self): |
67 | | - if not self.tiles: |
68 | | - return |
69 | | - last_tile_rank = len(self.tiles[0]) |
70 | | - for tile in self.tiles: |
71 | | - if len(tile) > last_tile_rank: |
72 | | - raise ValueError("Tiles must have a decreasing rank") |
73 | | - if not tile: |
74 | | - raise ValueError("Tiles must not be empty") |
75 | | - if any(d <= 0 for d in tile): |
76 | | - raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}") |
77 | | - last_tile_rank = len(tile) |
78 | | - |
79 | | - def __str__(self): |
80 | | - return f"Tiling({''.join(map(str, self.tiles))})" |
81 | | - |
82 | | - def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: |
83 | | - """Computes the shape of an array after tiling.""" |
84 | | - orig_shape = shape |
85 | | - def fail(): |
86 | | - raise ValueError(f"Tiling {self.tiles} does not apply to shape {orig_shape}") |
87 | | - for tile in self.tiles: |
88 | | - if len(tile) > len(shape): |
89 | | - fail() |
90 | | - untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):] |
91 | | - if any(s % t != 0 for s, t in zip(tiled_dims, tile)): |
92 | | - fail() |
93 | | - shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile) |
94 | | - return shape |
95 | | - |
96 | | - def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: |
97 | | - """Computes the shape of an array before tiling from its tiled shape.""" |
98 | | - orig_shape = shape |
99 | | - def fail(): |
100 | | - raise ValueError( |
101 | | - f"shape {orig_shape} is not a valid result of applying tiling {self}." |
102 | | - ) |
103 | | - for tile in reversed(self.tiles): |
104 | | - if len(tile) > len(shape): |
105 | | - fail() |
106 | | - untiled_dims = shape[:-2 * len(tile)] |
107 | | - tiled_dims = shape[-2 * len(tile):-len(tile)] |
108 | | - tiling_dims = shape[-len(tile):] |
109 | | - if tiling_dims != tile: |
110 | | - fail() |
111 | | - shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile))) |
112 | | - return shape |
113 | | - |
114 | | - def canonicalize(self) -> Tiling: |
115 | | - """Returns a canonicalized version of the tiling. |
116 | | -
|
117 | | - We define a tiling to be canonical if, at each step (except the first one, |
118 | | - which defines the base tile shape): |
119 | | -
|
120 | | - 1. The tiling partitions at least one dimension in more than 1 tile. For |
121 | | - example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it |
122 | | - yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which |
123 | | - allows getting rid of the unnecessary `1` dimensions. |
124 | | - 2. The leading dimensions of each tile are not `1`. If canonicalizing a |
125 | | - tile in this way leads to an empty tile, then the tile is given shape |
126 | | - `(1,)`---which is still a meaningful (final) tile. For example, the |
127 | | - tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape |
128 | | - `(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows |
129 | | - getting rid of the unnecessary `1` dimension, and yields a shape |
130 | | - `(8, 2, 4)`. |
131 | | - """ |
132 | | - if len(self.tiles) <= 1: |
133 | | - return self |
| 49 | +Tiling = mgpu.dialect.Tiling |
134 | 50 |
|
135 | | - shape = self.tiles[0] |
136 | | - new_tiling = [self.tiles[0]] |
137 | | - for tile in self.tiles[1:]: |
138 | | - for i, d in enumerate(tile): |
139 | | - if d != 1: |
140 | | - canonical_tile = tile[i:] |
141 | | - break |
142 | | - else: |
143 | | - canonical_tile = (1,) |
144 | | - tiled_dims = shape[-len(canonical_tile):] |
145 | | - if tiled_dims == canonical_tile: |
146 | | - continue |
147 | | - shape = canonical_tile |
148 | | - new_tiling.append(canonical_tile) |
149 | | - return Tiling(tuple(new_tiling)) |
150 | | - |
151 | | - def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: |
152 | | - """Computes the strides of an array after tiling.""" |
153 | | - for tile in self.tiles: |
154 | | - untiled, tiled = strides[:-len(tile)], strides[-len(tile):] |
155 | | - strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) |
156 | | - return strides |
157 | | - |
158 | | - def tile_dimension(self, dim: int) -> tuple[bool, ...]: |
159 | | - """Result is True whenever the tiled dim originated from the given input dim.""" |
160 | | - tiling_rank = len(self.tiles[0]) |
161 | | - if dim < 0 or dim >= tiling_rank: |
162 | | - raise ValueError(f"Invalid dimension {dim} for tiling {self}") |
163 | | - strides = [1] * tiling_rank |
164 | | - strides[dim] = 0 |
165 | | - return tuple(s == 0 for s in self.tile_strides(tuple(strides))) |
166 | | - |
167 | | - def remove_dimension(self, dim: int) -> Tiling: |
168 | | - """Returns a tiling with the given dimension removed.""" |
169 | | - tiling_rank = len(self.tiles[0]) |
170 | | - if dim < 0 or dim >= tiling_rank: |
171 | | - raise ValueError(f"Invalid dimension {dim} for tiling {self}") |
172 | | - dim_in_tile = dim |
173 | | - tiles = [] |
174 | | - last_tile_rank = len(self.tiles[0]) |
175 | | - for t in self.tiles: |
176 | | - assert last_tile_rank >= len(t) |
177 | | - dim_in_tile -= last_tile_rank - len(t) |
178 | | - last_tile_rank = len(t) |
179 | | - if dim_in_tile >= 0: |
180 | | - t = t[:dim_in_tile] + t[dim_in_tile + 1:] |
181 | | - if not t: # If this tile is empty, all other tiles will be empty too. |
182 | | - break |
183 | | - tiles.append(t) |
184 | | - return Tiling(tuple(tiles)) |
185 | | - |
186 | | - def tile_nested_shape_strides( |
187 | | - self, |
188 | | - shape: tuple[tuple[int, ...], ...], |
189 | | - strides: tuple[tuple[int, ...], ...], |
190 | | - ) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]: |
191 | | - """A fused version of `tile_shape` and `tile_strides` for nested shapes. |
192 | | -
|
193 | | - By nested shape we mean that each logical dimension (i.e. each element of |
194 | | - shape/strides) is actually composed out of multiple physical dimensions. |
195 | | - For example, a row-major array of logical shape (128, 128) that is tiled |
196 | | - into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each |
197 | | - dim is split into two sub-dims) and nested strides of |
198 | | - ((2 * 64 * 64, 64), (64 * 64, 1)). |
199 | | - """ |
200 | | - if len(shape) != len(strides): |
201 | | - raise ValueError( |
202 | | - f"Shape {shape} and strides {strides} must have the same length" |
203 | | - ) |
204 | | - def fail_if(cond, shape=shape): # Capture shape now. |
205 | | - if cond: |
206 | | - raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") |
207 | | - for tile in self.tiles: |
208 | | - fail_if(len(tile) > len(shape)) |
209 | | - untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):] |
210 | | - untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):] |
211 | | - major_dim_shapes, major_dim_strides = [], [] |
212 | | - minor_dim_shapes, minor_dim_strides = [], [] |
213 | | - for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides): |
214 | | - major_dim_shape_rev, major_dim_stride_rev = [], [] |
215 | | - minor_dim_shape_rev, minor_dim_stride_rev = [], [] |
216 | | - for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True): |
217 | | - if d < t: # We will need to tile more dims |
218 | | - fail_if(t % d != 0) |
219 | | - t //= d |
220 | | - minor_dim_shape_rev.append(d) |
221 | | - minor_dim_stride_rev.append(s) |
222 | | - elif t != 1: # Last dim to tile! |
223 | | - fail_if(d % t != 0) |
224 | | - minor_dim_shape_rev.append(t) |
225 | | - minor_dim_stride_rev.append(s) |
226 | | - if d != t: # No need to insert singleton dims. |
227 | | - major_dim_shape_rev.append(d // t) |
228 | | - major_dim_stride_rev.append(s * t) |
229 | | - t = 1 |
230 | | - else: # Done tiling! |
231 | | - major_dim_shape_rev.append(d) |
232 | | - major_dim_stride_rev.append(s) |
233 | | - fail_if(t != 1) |
234 | | - major_dim_shapes.append(major_dim_shape_rev[::-1]) |
235 | | - minor_dim_shapes.append(minor_dim_shape_rev[::-1]) |
236 | | - major_dim_strides.append(major_dim_stride_rev[::-1]) |
237 | | - minor_dim_strides.append(minor_dim_stride_rev[::-1]) |
238 | | - shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) # type: ignore[arg-type] |
239 | | - strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) # type: ignore[arg-type] |
240 | | - return ( |
241 | | - tuple(tuple(d) if d else (1,) for d in shape), |
242 | | - tuple(tuple(d) if d else (1,) for d in strides), |
243 | | - ) |
244 | | - |
245 | | - def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: |
246 | | - for tile in self.tiles: |
247 | | - untiled, tiled = indices[:-len(tile)], indices[-len(tile):] |
248 | | - indices = ( |
249 | | - *untiled, |
250 | | - *(i // t for i, t in zip(tiled, tile)), |
251 | | - *(i % t for i, t in zip(tiled, tile)), |
252 | | - ) |
253 | | - return indices |
254 | | - |
255 | | - def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: |
256 | | - for tile in reversed(self.tiles): |
257 | | - untiled = indices[:-2 * len(tile)] |
258 | | - outer = indices[-2 * len(tile):-len(tile)] |
259 | | - inner = indices[-len(tile):] |
260 | | - indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile))) |
261 | | - return indices |
262 | 51 |
|
263 | 52 | def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: |
264 | 53 | """Like built-in enumerate, but returns negative indices into the sequence.""" |
|
0 commit comments