
Commit 7a007ea

bchetioui authored and Google-ML-Automation committed
[Pallas/Mosaic GPU] Add support for s8 WGMMA with lhs in registers.

The low-level support only allows `swizzle=64` or larger for the time being.

PiperOrigin-RevId: 826554680
1 parent 14d89b5 commit 7a007ea

File tree

3 files changed: +38 −0 lines changed

jax/_src/pallas/mosaic_gpu/core.py
jax/experimental/mosaic/gpu/__init__.py
tests/pallas/mosaic_gpu_test.py
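Before the raw diffs, a minimal kernel-side sketch of what the commit enables, mirroring the test added below. This is an illustration, not part of the commit; the (64, 192) accumulator shape, the dtypes, and the API calls are taken from the new test.

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def kernel(a_ref, b_ref, o_ref):
  def scope(acc_ref):
    # New in this commit: load the s8 lhs from SMEM into registers using
    # the 8-bit WGMMA register layout.
    a_regs = plgpu.load(a_ref, (), layout=plgpu.Layout.WGMMA_8BIT)
    # The rhs stays in SMEM; here it is stored as (n, k), so the ref is
    # transposed before the WGMMA.
    plgpu.wgmma(acc_ref, a_regs, plgpu.transpose_ref(b_ref, (1, 0)))
    return acc_ref[...]
  # s8 x s8 products are accumulated into an int32 accumulator.
  o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.int32))

Per the commit message, the operands' memory transforms must use swizzle=64 or larger; the test below requests this via self.default_transforms(swizzle=64, dtype=input_dtype).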

jax/_src/pallas/mosaic_gpu/core.py
Lines changed: 4 additions & 0 deletions

@@ -1437,6 +1437,7 @@ def to_mgpu(self) -> mgpu.FragmentedLayout:
 class Layout(SomeLayout, enum.Enum):
   #: [m, n] matrix, where m % 64 == 0 == n % 8.
   WGMMA = enum.auto()
+  WGMMA_8BIT = enum.auto()
   WGMMA_UPCAST_2X = enum.auto()
   WGMMA_UPCAST_4X = enum.auto()
   WGMMA_TRANSPOSED = enum.auto()
@@ -1472,6 +1473,9 @@ def check_no_args():
       case Layout.WGMMA:
         check_no_args()
         return mgpu.WGMMA_LAYOUT
+      case Layout.WGMMA_8BIT:
+        check_no_args()
+        return mgpu.WGMMA_LAYOUT_8BIT
       case Layout.WGMMA_UPCAST_2X:
         check_no_args()
         return mgpu.WGMMA_LAYOUT_UPCAST_2X
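Net effect of the two hunks: the new Pallas-level enum member resolves to the low-level Mosaic GPU layout and takes no layout parameters (that is what check_no_args() enforces). A minimal sketch of the mapping, assuming Layout is reachable through the plgpu namespace as in the test below:

import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.pallas import mosaic_gpu as plgpu

# The new match case forwards WGMMA_8BIT to the low-level layout constant
# that the next file re-exports.
assert plgpu.Layout.WGMMA_8BIT.to_mgpu() == mgpu.WGMMA_LAYOUT_8BIT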

jax/experimental/mosaic/gpu/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@
     TCGEN05_COL_LAYOUT as TCGEN05_COL_LAYOUT,
     TiledLayout as TiledLayout,
     WGMMA_LAYOUT as WGMMA_LAYOUT,
+    WGMMA_LAYOUT_8BIT as WGMMA_LAYOUT_8BIT,
     WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT,
     WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT,
     WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT,
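This re-export makes the new layout importable from the package root. A minimal sketch, assuming a JAX build that contains this commit:

# The one-line hunk above is what makes this import resolve.
from jax.experimental.mosaic.gpu import WGMMA_LAYOUT_8BIT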

tests/pallas/mosaic_gpu_test.py
Lines changed: 33 additions & 0 deletions

@@ -2997,6 +2997,39 @@ def scope(acc_ref):
     )(a, b)
     np.testing.assert_allclose(res, a @ b, rtol=1e-3)

+  def test_wgmma_registers_integer(self):
+    # TODO(bchetioui): plumb in is_signed into WGMMA lowering and allow an
+    # integer accumulator type to be created.
+    self.skip_if_wg_semantics()
+    input_dtype = jnp.int8
+    out_dtype = jnp.int32
+    def kernel(a_ref, b_ref, o_ref):
+      def scope(acc_ref):
+        a_regs = plgpu.load(a_ref, (), layout=plgpu.Layout.WGMMA_8BIT)
+        plgpu.wgmma(acc_ref, a_regs, plgpu.transpose_ref(b_ref, (1, 0)))
+        return acc_ref[...]
+      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), out_dtype))
+
+    key1, key2 = jax.random.split(jax.random.key(42), 2)
+    m = 64
+    k = 128
+    n = 192
+    a = jax.random.randint(key1, shape=(m, k), minval=-128, maxval=127, dtype=input_dtype)
+    b = jax.random.randint(key2, shape=(n, k), minval=-128, maxval=127, dtype=input_dtype)
+
+    transforms = self.default_transforms(swizzle=64, dtype=input_dtype)
+    res = self.pallas_call(
+        kernel,
+        in_specs=[
+            plgpu.BlockSpec(transforms=transforms),
+            plgpu.BlockSpec(transforms=transforms),
+        ],
+        out_shape=jax.ShapeDtypeStruct((64, 192), out_dtype),
+    )(a, b)
+    np.testing.assert_array_equal(
+        res, a.astype(out_dtype) @ b.T.astype(out_dtype)
+    )
+
   def test_wgmma_registers_init(self):
     def kernel(a_ref, b_ref, i_ref, o_ref):
       def scope(acc_ref):
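A note on the reference check: b is generated as an (n, k) array and transposed inside the kernel via plgpu.transpose_ref, so the host-side reference multiplies by b.T, widening both operands to int32 first. Because integer WGMMA accumulates exactly, the test can use np.testing.assert_array_equal rather than the rtol-based comparison of the floating-point test above.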
