Skip to content

Commit dd12a48

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Add support for arbitrary reshapes of contiguous refs
PiperOrigin-RevId: 840650264
1 parent 92621c7 commit dd12a48

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

jax/experimental/mosaic/gpu/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,18 @@ def memref_reshape(
722722
(), ref_ty.element_type, new_layout, ref_ty.memory_space
723723
)
724724
return memref.collapse_shape(result_ty, ref, [])
725+
# For contiguous refs we can do arbitrary reshapes easily.
726+
strides, _ = ref_ty.get_strides_and_offset()
727+
if all(
728+
d == 1 or s1 == s2
729+
for d, s1, s2 in zip(
730+
ref_ty.shape,
731+
get_contiguous_strides(ref_ty.shape),
732+
strides,
733+
strict=True,
734+
)
735+
):
736+
return memref_unfold(memref_fold(ref, 0, ref_ty.rank), 0, shape)
725737
return _reshape(ref, src_shape, dst_shape)
726738

727739

tests/mosaic/gpu_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ def kernel(ctx, inp, out, _):
394394
("un", (1, 10, 1), (1, 5, 2, 1,)),
395395
("to_scalar", (1, 1, 1), ()),
396396
("from_scalar", (), (1, 1, 1)),
397+
("arbitrary", (2 * 5, 7 * 3), (2, 7, 5, 3)),
397398
)
398399
def test_reshape(self, inp_shape, out_shape):
399400
def kernel(ctx, inp, out, _):

0 commit comments

Comments
 (0)