@@ -86,7 +86,7 @@ def kernel(ctx, dst, _):
8686 y_np = multihost_utils .process_allgather (y , tiled = True )
8787 np .testing .assert_array_equal (y_np , np .arange (jax .device_count ()))
8888
89- def test_remote_async_copy (self ):
89+ def test_remote_async_copy_basic (self ):
9090 i32 = ir .IntegerType .get_signless (32 )
9191 def kernel (ctx , src , sem , dst , scratch ):
9292 tmp , barrier = scratch
@@ -124,6 +124,48 @@ def kernel(ctx, src, sem, dst, scratch):
124124 y_np , np .concatenate (np .split (x_np , 2 )[::- 1 ], axis = 0 )
125125 )
126126
def test_remote_async_copy_add(self):
  """Exercise a remote `async_copy` that uses `reduction_op="add"`.

  Runs on a 2-device mesh. Each device stages its shard of `x` in the
  scratch buffer, fills its own output with 1.0, and then copies the
  staged shard into the *other* device's output with an "add" reduction.
  The gathered result is expected to equal `1 + x` with the two halves of
  `x` (split along axis 0) swapped.
  """
  int32_type = ir.IntegerType.get_signless(32)

  def kernel(ctx, src, sem, dst, scratch):
    staging, tma_barrier = scratch

    # With exactly two devices, `1 - device_id` maps 0 <-> 1.
    one = arith.constant(int32_type, 1)
    peer = arith.subi(one, ctx.device_id())
    peer_sem_ptr = mgpu.utils.memref_ptr(ctx.to_remote(sem, peer))
    peer_sem = mgpu.SemaphoreRef(peer_sem_ptr)
    local_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem))

    # Stage the local shard in the scratch buffer.
    ctx.async_copy(src_ref=src, dst_ref=staging, barrier=tma_barrier)
    tma_barrier.wait()

    # Initialize the local output to all ones; the (32, 64) shape matches
    # one shard of the (64, 64) input split across the 2-device "x" axis.
    f32_one = arith.constant(ir.F32Type.get(), 1.0)
    ones = fa.FragmentedArray.splat(f32_one, (32, 64))
    ones.store_untiled(dst)
    mgpu.warpgroup_barrier()

    # First handshake. NOTE(review): presumably guarantees the peer has
    # finished initializing its `dst` before we add into it -- confirm.
    peer_sem.signal(1)
    local_sem.wait(1)

    # Add the staged shard into the peer's output buffer.
    ctx.async_copy(
        src_ref=staging,
        dst_ref=dst,
        gmem_peer_id=peer,
        reduction_op="add",
    )
    ctx.await_async_copy(0)

    # Second handshake. NOTE(review): presumably keeps either device from
    # finishing before the peer's remote add has been issued -- confirm.
    peer_sem.signal(1)
    local_sem.wait(1)

  def run_on_shard(x, sem):
    # `x` doubles as the input value, the output shape, and the shape of
    # the scratch buffer; `sem` is passed through as an in/out operand.
    launch = mgpu.as_gpu_kernel(
        kernel,
        (1, 1, 1),
        (128, 1, 1),
        x,
        x,
        (x, mgpu.TMABarrier()),
        inout_shape=sem,
    )
    return launch(x, sem)

  mesh = jax.make_mesh(
      (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
  )
  with jax.set_mesh(mesh):
    x_np = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
    x = jax.sharding.reshard(x_np, P("x"))
    sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P())

    sharded_fn = jax.shard_map(
        run_on_shard,
        in_specs=(P("x"), P(None)),
        out_specs=[P("x"), P(None)],
        check_vma=False,
    )
    y, _ = jax.jit(sharded_fn)(x, sem)
    y_np = multihost_utils.process_allgather(y, tiled=True)

    # Each device ends up with 1.0 plus the other device's shard.
    first_half, second_half = np.split(x_np, 2)
    expected = 1 + np.concatenate([second_half, first_half], axis=0)
    np.testing.assert_array_equal(y_np, expected)
168+
127169 def test_remote_semaphore (self ):
128170 i32 = ir .IntegerType .get_signless (32 )
129171 def kernel (ctx , sem , _ ):
0 commit comments