Skip to content

Commit cc8ccfd

Browse files
apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add a test for remote TMA store-add
PiperOrigin-RevId: 829396816
1 parent bdc7cf0 commit cc8ccfd

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

tests/mosaic/gpu_test_distributed.py

Lines changed: 43 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ def kernel(ctx, dst, _):
8686
y_np = multihost_utils.process_allgather(y, tiled=True)
8787
np.testing.assert_array_equal(y_np, np.arange(jax.device_count()))
8888

89-
def test_remote_async_copy(self):
89+
def test_remote_async_copy_basic(self):
9090
i32 = ir.IntegerType.get_signless(32)
9191
def kernel(ctx, src, sem, dst, scratch):
9292
tmp, barrier = scratch
@@ -124,6 +124,48 @@ def kernel(ctx, src, sem, dst, scratch):
124124
y_np, np.concatenate(np.split(x_np, 2)[::-1], axis=0)
125125
)
126126

127+
def test_remote_async_copy_add(self):
  """Checks a remote async copy that uses ``reduction_op="add"``.

  A 64x64 float32 array is sharded along mesh axis "x" over two devices.
  Each device writes ones into its own ``dst``, then issues an async copy
  of its staged shard into the peer's ``dst`` with an "add" reduction. The
  gathered result is compared against ``1 +`` the input with its two halves
  swapped.
  """
  i32 = ir.IntegerType.get_signless(32)

  def kernel(ctx, src, sem, dst, scratch):
    staging, tma_barrier = scratch

    # The mesh has exactly two devices, so the peer's id is 1 - own id.
    one = arith.constant(i32, 1)
    peer = arith.subi(one, ctx.device_id())
    peer_sem_ref = ctx.to_remote(sem, peer)
    peer_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(peer_sem_ref))
    local_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem))

    # Stage the local shard into scratch and wait for it to arrive.
    ctx.async_copy(src_ref=src, dst_ref=staging, barrier=tma_barrier)
    tma_barrier.wait()

    # Initialise the local output with ones.
    f32_one = arith.constant(ir.F32Type.get(), 1.0)
    ones = fa.FragmentedArray.splat(f32_one, (32, 64))
    ones.store_untiled(dst)
    mgpu.warpgroup_barrier()

    # Handshake with the peer before touching its output. Presumably this
    # guarantees the peer's ones are written before the add lands.
    peer_sem.signal(1)
    local_sem.wait(1)

    # Add-reduce the staged shard into the peer's output.
    ctx.async_copy(
        src_ref=staging,
        dst_ref=dst,
        gmem_peer_id=peer,
        reduction_op="add",
    )
    ctx.await_async_copy(0)

    # Second handshake after the remote add has been awaited.
    peer_sem.signal(1)
    local_sem.wait(1)

  mesh = jax.make_mesh(
      (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
  )
  with jax.set_mesh(mesh):
    x_np = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
    x = jax.sharding.reshard(x_np, P("x"))
    sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P())

    def run_on_shard(x_shard, sem_shard):
      # One block of 128 threads. The shard doubles as the input, output
      # and scratch shape, and the semaphore is passed as an in/out operand.
      launch = mgpu.as_gpu_kernel(
          kernel,
          (1, 1, 1),
          (128, 1, 1),
          x_shard,
          x_shard,
          (x_shard, mgpu.TMABarrier()),
          inout_shape=sem_shard,
      )
      return launch(x_shard, sem_shard)

    sharded = jax.shard_map(
        run_on_shard,
        in_specs=(P("x"), P(None)),
        out_specs=[P("x"), P(None)],
        check_vma=False,
    )
    y, _ = jax.jit(sharded)(x, sem)

    y_np = multihost_utils.process_allgather(y, tiled=True)
    # Each half of the output holds the other device's shard plus one.
    swapped_halves = np.concatenate(np.split(x_np, 2)[::-1], axis=0)
    np.testing.assert_array_equal(y_np, 1 + swapped_halves)
168+
127169
def test_remote_semaphore(self):
128170
i32 = ir.IntegerType.get_signless(32)
129171
def kernel(ctx, sem, _):

0 commit comments

Comments
 (0)