Skip to content

Commit f5a09b8

Browse files
[Mosaic GPU] Add support for all kinds of TMA reductions.
I had to change the tma descriptor cache key, since there are cases where we currently need two different descriptors based on the reduction op. We could in principle go back to a single TMA descriptor in those cases if we pass sign information to async_copy. PiperOrigin-RevId: 845269465
1 parent 6860386 commit f5a09b8

File tree

4 files changed

+184
-67
lines changed

4 files changed

+184
-67
lines changed

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 100 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,24 @@
3939

4040
TMA_DESCRIPTOR_BYTES = 128
4141
TMA_DESCRIPTOR_ALIGNMENT = 64
42-
TMAReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"]
42+
TMAReductionOp = Literal[
43+
"add",
44+
"min",
45+
"max",
46+
"inc",
47+
"dec",
48+
"and",
49+
"or",
50+
"xor",
51+
"umin",
52+
"umax",
53+
"smin",
54+
"smax",
55+
]
56+
57+
def _reduction_op_to_ptx(reduction_op: TMAReductionOp) -> str:
58+
# convert [s|u]min|max to min|max
59+
return reduction_op[-3:]
4360

4461
c = utils.c # This is too common to fully qualify.
4562

@@ -426,6 +443,81 @@ def _find_kernel_argument_for_gmem_ref(
426443
return gmem_ref
427444

428445

446+
def _is_tma_reduction_op_supported(
447+
reduction_op: TMAReductionOp | None, dtype: ir.Type,
448+
) -> bool:
449+
"""Returns whether the given TMA reduction op supports the given dtype.
450+
451+
This function essentially implements the table at:
452+
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor
453+
with the following differences:
454+
- For `add` reductions, we also support int64, treating it as uint64.
455+
- For `and`, `or`, and `xor` reductions, we support signed integer types.
456+
- For `inc` and `dec` reductions, we support both signed and unsigned i32
457+
treating both as unsigned.
458+
"""
459+
i32 = ir.IntegerType.get_signless(32)
460+
i64 = ir.IntegerType.get_signless(64)
461+
f16 = ir.F16Type.get()
462+
f32 = ir.F32Type.get()
463+
bf16 = ir.BF16Type.get()
464+
465+
match reduction_op:
466+
case None:
467+
return True
468+
case "add":
469+
return dtype in (f16, f32, bf16, i32, i64)
470+
case "max" | "min":
471+
return dtype in (f16, bf16)
472+
case "umax" | "umin" | "smax" | "smin":
473+
return dtype in (i32, i64)
474+
case "inc" | "dec":
475+
return dtype == i32
476+
case "and" | "or" | "xor":
477+
return dtype in (i32, i64)
478+
479+
480+
def _tma_dma_type(
481+
element_type: ir.Type,
482+
reduction_op: TMAReductionOp | None,
483+
) -> int:
484+
"""Returns the TMA DMA type for the given element type and signedness."""
485+
if ir.IntegerType.isinstance(element_type):
486+
bitwidth = utils.bitwidth_impl(element_type)
487+
if bitwidth == 2:
488+
tma_dtype = 8
489+
elif bitwidth == 4:
490+
tma_dtype = 0
491+
elif bitwidth == 8:
492+
tma_dtype = 1
493+
elif bitwidth == 16:
494+
tma_dtype = 2
495+
elif bitwidth == 32:
496+
tma_dtype = 9 if reduction_op in ("smin", "smax") else 3
497+
elif bitwidth == 64:
498+
tma_dtype = 10 if reduction_op in ("smin", "smax") else 4
499+
else:
500+
raise ValueError(f"Unsupported integer bitwidth: {bitwidth}")
501+
elif ir.F16Type.isinstance(element_type):
502+
tma_dtype = 5
503+
elif ir.F32Type.isinstance(element_type):
504+
tma_dtype = 6
505+
elif ir.BF16Type.isinstance(element_type):
506+
tma_dtype = 7
507+
# We treat narrow floats as integers
508+
elif ir.Float8E5M2Type.isinstance(element_type):
509+
tma_dtype = 1
510+
elif ir.Float8E4M3FNType.isinstance(element_type):
511+
tma_dtype = 1
512+
elif ir.Float8E8M0FNUType.isinstance(element_type):
513+
tma_dtype = 1
514+
elif ir.Float4E2M1FNType.isinstance(element_type):
515+
tma_dtype = 0
516+
else:
517+
raise ValueError(f"unsupported TMA dtype {element_type}")
518+
return tma_dtype
519+
520+
429521
class AsyncCopyImplementation(enum.Enum):
430522
TMA = enum.auto()
431523
CP_ASYNC = enum.auto()
@@ -438,7 +530,7 @@ class LaunchContext:
438530
cluster_size: tuple[int, int, int]
439531
profiler: OnDeviceProfiler | None = None
440532
tma_descriptors: dict[
441-
tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any],
533+
tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any, int],
442534
ir.Value,
443535
] = dataclasses.field(default_factory=dict, init=False)
444536
is_device_collective: bool = False
@@ -512,10 +604,11 @@ def _get_tma_desc(
512604
reduction_op: TMAReductionOp | None,
513605
):
514606
gmem_ref = _find_kernel_argument_for_gmem_ref(gmem_ref)
607+
tma_dtype = _tma_dma_type(ir.MemRefType(gmem_ref.type).element_type, reduction_op)
515608
# Using ir.Values in cache keys is a little sketchy, but I think it should
516609
# be fine. Having it in the key will keep it alive, and if comparison and
517610
# hashing is by identity then it should work out.
518-
tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id)
611+
tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id, tma_dtype)
519612
if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
520613
i32 = ir.IntegerType.get_signless(32)
521614
i64 = ir.IntegerType.get_signless(64)
@@ -580,43 +673,6 @@ def init_tma_desc(host_ptr):
580673
)
581674
# TODO(apaszke): Better verification (e.g. slice is non-zero)
582675
# TODO(apaszke): We always know strides statically.
583-
if isinstance(ref_ty.element_type, ir.IntegerType):
584-
if reduction_op is not None:
585-
raise ValueError(
586-
f"TMA with reduction_op={reduction_op} is not supported with Integers"
587-
)
588-
bitwidth = utils.bitwidth_impl(ref_ty.element_type)
589-
if bitwidth == 2:
590-
tma_dtype = 8
591-
elif bitwidth == 4:
592-
tma_dtype = 0
593-
elif bitwidth == 8:
594-
tma_dtype = 1
595-
elif bitwidth == 16:
596-
tma_dtype = 2
597-
elif bitwidth == 32:
598-
tma_dtype = 3
599-
elif bitwidth == 64:
600-
tma_dtype = 4
601-
else:
602-
raise ValueError(f"Unsupported integer bitwidth: {bitwidth}")
603-
elif ir.F16Type.isinstance(ref_ty.element_type):
604-
tma_dtype = 5
605-
elif ir.F32Type.isinstance(ref_ty.element_type):
606-
tma_dtype = 6
607-
elif ir.BF16Type.isinstance(ref_ty.element_type):
608-
tma_dtype = 7
609-
# We treat narrow floats as integers
610-
elif ir.Float8E5M2Type.isinstance(ref_ty.element_type):
611-
tma_dtype = 1
612-
elif ir.Float8E4M3FNType.isinstance(ref_ty.element_type):
613-
tma_dtype = 1
614-
elif ir.Float8E8M0FNUType.isinstance(ref_ty.element_type):
615-
tma_dtype = 1
616-
elif ir.Float4E2M1FNType.isinstance(ref_ty.element_type):
617-
tma_dtype = 0
618-
else:
619-
raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}")
620676
dtype_or_bitwidth = c(tma_dtype, i64)
621677
args = [
622678
host_ptr,
@@ -953,16 +1009,10 @@ def async_copy(
9531009
if reduction_op is not None:
9541010
if implementation != AsyncCopyImplementation.TMA:
9551011
raise ValueError("Only the TMA implementation supports reductions")
956-
if not any(
957-
t.isinstance(element_type)
958-
for t in (ir.F32Type, ir.BF16Type, ir.F16Type)
959-
):
960-
raise ValueError(
961-
"TMA with reduction is only supported with f32, f16 and bf16"
962-
)
963-
if reduction_op != "add":
1012+
if not _is_tma_reduction_op_supported(reduction_op, element_type):
9641013
raise ValueError(
965-
"TMA with reduction is only supported with add operation"
1014+
f"Reduction op {reduction_op} not supported by the TMA"
1015+
f" implementation for element type {element_type}"
9661016
)
9671017

9681018
if src_ref_ty.memory_space is None and utils.is_smem_ref(dst_ref_ty):
@@ -1329,7 +1379,7 @@ def async_copy(
13291379
llvm.inline_asm(
13301380
ir.Type.parse("!llvm.void"),
13311381
[predicate,smem_ptr,tma_desc,*rev_dyn_base_indices],
1332-
f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];",
1382+
f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{_reduction_op_to_ptx(reduction_op)}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];",
13331383
"b,r,l" + ",r" * rank,
13341384
has_side_effects=True,
13351385
)

jaxlib/mosaic/dialect/gpu/mosaic_gpu.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,11 @@ def MosaicGPU_TMAReduction : I32EnumAttr<"TMAReduction",
205205
I32EnumAttrCase<"Dec", 4, "dec">,
206206
I32EnumAttrCase<"And", 5, "and">,
207207
I32EnumAttrCase<"Or", 6, "or">,
208-
I32EnumAttrCase<"Xor", 7, "xor">
208+
I32EnumAttrCase<"Xor", 7, "xor">,
209+
I32EnumAttrCase<"Umin", 8, "umin">,
210+
I32EnumAttrCase<"Umax", 9, "umax">,
211+
I32EnumAttrCase<"Smin", 10, "smin">,
212+
I32EnumAttrCase<"Smax", 11, "smax">
209213
]>{
210214
let cppNamespace = "::mosaic_gpu";
211215
}

jaxlib/mosaic/gpu/runtime.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
4747

4848
CUtensorMapDataType data_type;
4949
int64_t elem_bitwidth;
50-
// types are defined in: LaunchContext._get_tma_desc()
50+
// types are defined in: launch_context._tma_dma_type()
5151
if (elem_type == 8){
5252
// this is for int2s
5353
data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
@@ -77,7 +77,13 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
7777
} else if (elem_type == 7){
7878
data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
7979
elem_bitwidth = 16;
80-
} else{
80+
} else if (elem_type == 9){
81+
data_type = CU_TENSOR_MAP_DATA_TYPE_INT32;
82+
elem_bitwidth = 32;
83+
} else if (elem_type == 10){
84+
data_type = CU_TENSOR_MAP_DATA_TYPE_INT64;
85+
elem_bitwidth = 64;
86+
} else{
8187
fprintf(stderr, "Unsupported element type: %ld \n", elem_type);
8288
abort();
8389
}

tests/mosaic/gpu_test.py

Lines changed: 71 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5397,29 +5397,57 @@ def body(
53975397
x = self.prng.uniform(0, 10, input_shape).astype(el_type)
53985398
self.assertArraysEqual(kernel(x), x.reshape(output_shape))
53995399

5400-
@parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.float16)
5401-
def test_async_store_add_reduction(self, dtype):
5402-
# TODO(b/415721295):Remove after the minimal jaxlib version is 0.8.2.
5400+
@parameterized.product(
5401+
dtype=(jnp.int32, jnp.int64, jnp.uint32, jnp.uint64, jnp.float32, jnp.float16, jnp.bfloat16),
5402+
reduction_op=("add", "min", "max", "inc", "dec", "and", "or", "xor"),
5403+
)
5404+
def test_async_store_reduction(self, dtype, reduction_op):
5405+
# TODO(b/415721295):Clean up after the minimal jaxlib version is 0.8.2.
54035406
if not hasattr(mgpu_dialect, "TMAReduction"):
5404-
self.skipTest("TMAReduction op is required.")
5407+
self.skipTest("The mgpu_dialect.TMAReduction attribute is required.")
5408+
5409+
if reduction_op in ("min", "max"):
5410+
if dtype in (jnp.int32, jnp.int64):
5411+
reduction_op = "s" + reduction_op
5412+
elif dtype in (jnp.uint32, jnp.uint64):
5413+
reduction_op = "u" + reduction_op
5414+
5415+
if reduction_op in ("smin", "smax", "umin", "umax") and not hasattr(mgpu_dialect.TMAReduction, "Smin"):
5416+
self.skipTest("The Smin/Smax/Umin/Umax reduction types are required.")
5417+
5418+
if (
5419+
not launch_context._is_tma_reduction_op_supported(
5420+
reduction_op,
5421+
utils.dtype_to_ir_type(dtype),
5422+
)
5423+
or (
5424+
dtype in (jnp.uint32, jnp.uint64)
5425+
and reduction_op in ("smin", "smax")
5426+
)
5427+
or (
5428+
dtype in (jnp.int32, jnp.int64) and reduction_op in ("umin", "umax")
5429+
)
5430+
or dtype == jnp.int32 and reduction_op in ("inc", "dec")
5431+
):
5432+
self.skipTest("TMA does not support this reduction op for this dtype")
54055433

54065434
shape = (8, 128)
54075435

54085436
def body(ctx, src, dst, smem):
54095437
del ctx
5410-
smem_ref, tma_barrier = smem
5438+
src_smem_ref, tma_barrier = smem
54115439
i32 = ir.IntegerType.get_signless(32)
54125440
zero = arith.constant(i32, 0)
54135441
indices = [zero, zero]
5414-
slice_lengths = smem_ref.type.shape
5442+
slice_lengths = src_smem_ref.type.shape
54155443

54165444
tma_barrier.arrive_expect_tx(
5417-
utils.bitwidth(smem_ref.type.element_type) * math.prod(shape) // 8
5445+
utils.bitwidth(src_smem_ref.type.element_type) * math.prod(shape) // 8
54185446
)
54195447

54205448
mgpu_dialect.async_load(
54215449
source=src,
5422-
destination=smem_ref,
5450+
destination=src_smem_ref,
54235451
barrier=tma_barrier.as_barrier_memref(),
54245452
indices=indices,
54255453
slice_lengths=slice_lengths,
@@ -5428,31 +5456,60 @@ def body(ctx, src, dst, smem):
54285456

54295457
tma_barrier.wait()
54305458

5459+
reduction_attr = getattr(
5460+
mgpu_dialect.TMAReduction, reduction_op.capitalize()
5461+
)
5462+
54315463
mgpu_dialect.async_store(
5432-
source=smem_ref,
5464+
source=src_smem_ref,
54335465
destination=dst,
54345466
indices=indices,
54355467
slice_lengths=slice_lengths,
5436-
reduction_op=mgpu_dialect.TMAReduction.Add,
5468+
reduction_op=reduction_attr,
54375469
)
54385470
nvvm.cp_async_bulk_wait_group(0)
54395471

5440-
src = jnp.ones(shape, dtype=dtype)
5441-
dst = jnp.ones(shape, dtype=dtype)
5472+
prng_key = jax.random.key(1234)
5473+
k0, k1 = jax.random.split(prng_key, 2)
5474+
if dtype in (jnp.bfloat16, jnp.float16, jnp.float32):
5475+
src = jax.random.uniform(k0, shape, dtype, -10, 10)
5476+
dst = jax.random.uniform(k1, shape, dtype, -10, 10)
5477+
else:
5478+
src = jax.random.randint(k0, shape, -10, 10).astype(dtype)
5479+
dst = jax.random.randint(k1, shape, -10, 10).astype(dtype)
5480+
5481+
if reduction_op == "add":
5482+
expected = src + dst
5483+
elif reduction_op in ("min", "smin", "umin"):
5484+
expected = jnp.minimum(src, dst)
5485+
elif reduction_op in ("max", "smax", "umax"):
5486+
expected = jnp.maximum(src, dst)
5487+
elif reduction_op == "and":
5488+
expected = src & dst
5489+
elif reduction_op == "or":
5490+
expected = src | dst
5491+
elif reduction_op == "xor":
5492+
expected = src ^ dst
5493+
elif reduction_op == "inc":
5494+
expected = jnp.where(dst >= src, 0, dst + 1)
5495+
elif reduction_op == "dec":
5496+
expected = jnp.where((dst == 0) | (dst > src), src, dst - 1)
5497+
else:
5498+
raise ValueError(f"Unsupported reduction op: {reduction_op}")
54425499

54435500
jax_shape = jax.ShapeDtypeStruct(shape, dtype)
54445501
kernel = mgpu.as_gpu_kernel(
54455502
body,
54465503
grid=(1, 1, 1),
54475504
block=(128, 1, 1),
5448-
in_shape=(jax_shape,),
5505+
in_shape=(jax_shape),
54495506
out_shape=(),
54505507
inout_shape=(jax_shape,),
54515508
smem_scratch_shape=[jax_shape, core.TMABarrier(1)],
54525509
thread_semantics=mgpu.LoweringSemantics.Warpgroup,
54535510
)
54545511

5455-
np.testing.assert_array_equal(kernel(src, dst)[0], src + dst)
5512+
np.testing.assert_array_equal(kernel(src, dst)[0], expected)
54565513

54575514

54585515
class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):

0 commit comments

Comments
 (0)