Skip to content

Commit 238f05d

Browse files
allanrenucci authored and Google-ML-Automation committed
[Pallas:MGPU] Support non-tiled 3D+ transposed refs in the swap LANE lowering rule.
PiperOrigin-RevId: 844713064
1 parent 164d8bb commit 238f05d

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,12 +1753,16 @@ def _swap_lowering_rule(
17531753
layout=value.layout,
17541754
)
17551755
value.store_tiled(x_smem, swizzle=swizzle)
1756-
case () | (gpu_core.TransposeRef((1, 0)),):
1756+
case () | (gpu_core.TransposeRef(),):
17571757
transposed = bool(transforms)
17581758
match value.layout:
17591759
case mgpu.TiledLayout():
17601760
if transposed:
1761-
x_smem = mgpu.memref_transpose(x_smem, (1, 0))
1761+
assert isinstance(
1762+
transforms[0], gpu_core.TransposeRef
1763+
) # silence pytype
1764+
permutation = transforms[0].permutation
1765+
x_smem = mgpu.memref_transpose(x_smem, permutation)
17621766
old_value = mgpu.FragmentedArray.load_untiled(
17631767
x_smem,
17641768
layout=value.layout,

tests/pallas/mosaic_gpu_test.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,31 +1078,31 @@ def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref):
10781078
idx = jax.random.permutation(jax.random.key(1234), out_shape[0]).astype(jnp.uint32)
10791079
np.testing.assert_array_equal(kernel(x, idx), x[idx, 64:])
10801080

1081-
@parameterized.parameters(
1082-
(plgpu.Layout.WGMMA, plgpu.Layout.WGMMA_TRANSPOSED),
1083-
(plgpu.Layout.WGMMA_TRANSPOSED, plgpu.Layout.WGMMA),
1081+
@parameterized.product(
1082+
src_transposed=(False, True), shape=((128, 128), (1, 128, 128))
10841083
)
1085-
def test_transposed_load_store(self, src_layout, dst_layout):
1086-
def is_transposed(layout):
1087-
return layout == plgpu.Layout.WGMMA_TRANSPOSED
1088-
1089-
shape, dtype = (128, 128), jnp.float32
1090-
1084+
def test_transposed_load_store(self, src_transposed, shape):
1085+
dtype = jnp.float32
1086+
permutation = (0, 2, 1) if len(shape) == 3 else (1, 0)
10911087
@functools.partial(
10921088
self.kernel,
10931089
out_shape=jax.ShapeDtypeStruct(shape, dtype),
10941090
)
10951091
def kernel(src_ref, dst_ref):
1096-
if is_transposed(src_layout):
1097-
src_ref = src_ref.T
1098-
if is_transposed(dst_layout):
1099-
dst_ref = dst_ref.T
1092+
if src_transposed:
1093+
src_ref = plgpu.transpose_ref(src_ref, permutation)
1094+
src_layout = plgpu.Layout.WGMMA_TRANSPOSED
1095+
dst_layout = plgpu.Layout.WGMMA
1096+
else:
1097+
dst_ref = plgpu.transpose_ref(dst_ref, permutation)
1098+
src_layout = plgpu.Layout.WGMMA
1099+
dst_layout = plgpu.Layout.WGMMA_TRANSPOSED
11001100
src = plgpu.load(src_ref, (), layout=src_layout, optimized=False)
11011101
dst = plgpu.layout_cast(src, dst_layout)
11021102
dst_ref[...] = dst
11031103

11041104
x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape)
1105-
np.testing.assert_array_equal(kernel(x), x.T)
1105+
np.testing.assert_array_equal(kernel(x), jnp.transpose(x, permutation))
11061106

11071107
@parameterized.product(
11081108
src_memory_space=[plgpu.SMEM, plgpu.GMEM],

0 commit comments

Comments (0)