
TMA_DESCRIPTOR_BYTES = 128
TMA_DESCRIPTOR_ALIGNMENT = 64
-TMAReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"]
+TMAReductionOp = Literal[
+    "add",
+    "min",
+    "max",
+    "inc",
+    "dec",
+    "and",
+    "or",
+    "xor",
+    "umin",
+    "umax",
+    "smin",
+    "smax",
+]
+
+def _reduction_op_to_ptx(reduction_op: TMAReductionOp) -> str:
+  # Convert [s|u]min and [s|u]max to min/max; all other op names are unchanged.
+  return reduction_op[-3:]

c = utils.c  # This is too common to fully qualify.

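Side note (illustrative, not part of the diff): `_reduction_op_to_ptx` relies on the `[-3:]` slice because only the signed/unsigned min/max variants carry a one-letter prefix; every other op name is at most three characters long, so it passes through unchanged. A minimal sketch of the expected mapping:

    assert _reduction_op_to_ptx("smin") == "min"
    assert _reduction_op_to_ptx("umax") == "max"
    assert _reduction_op_to_ptx("add") == "add"
    assert _reduction_op_to_ptx("xor") == "xor"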
@@ -426,6 +443,81 @@ def _find_kernel_argument_for_gmem_ref(
  return gmem_ref


+def _is_tma_reduction_op_supported(
+    reduction_op: TMAReductionOp | None, dtype: ir.Type,
+) -> bool:
+  """Returns whether the given TMA reduction op supports the given dtype.
+
+  This function essentially implements the table at:
+  https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor
+  with the following differences:
+  - For `add` reductions, we also support int64, treating it as uint64.
+  - For `and`, `or`, and `xor` reductions, we support signed integer types.
+  - For `inc` and `dec` reductions, we support both signed and unsigned i32,
+    treating both as unsigned.
+  """
+  i32 = ir.IntegerType.get_signless(32)
+  i64 = ir.IntegerType.get_signless(64)
+  f16 = ir.F16Type.get()
+  f32 = ir.F32Type.get()
+  bf16 = ir.BF16Type.get()
+
+  match reduction_op:
+    case None:
+      return True
+    case "add":
+      return dtype in (f16, f32, bf16, i32, i64)
+    case "max" | "min":
+      return dtype in (f16, bf16)
+    case "umax" | "umin" | "smax" | "smin":
+      return dtype in (i32, i64)
+    case "inc" | "dec":
+      return dtype == i32
+    case "and" | "or" | "xor":
+      return dtype in (i32, i64)
+
+
+def _tma_dma_type(
+    element_type: ir.Type,
+    reduction_op: TMAReductionOp | None,
+) -> int:
+  """Returns the TMA DMA type for the given element type and reduction op."""
+  if ir.IntegerType.isinstance(element_type):
+    bitwidth = utils.bitwidth_impl(element_type)
+    if bitwidth == 2:
+      tma_dtype = 8
+    elif bitwidth == 4:
+      tma_dtype = 0
+    elif bitwidth == 8:
+      tma_dtype = 1
+    elif bitwidth == 16:
+      tma_dtype = 2
+    elif bitwidth == 32:
+      tma_dtype = 9 if reduction_op in ("smin", "smax") else 3
+    elif bitwidth == 64:
+      tma_dtype = 10 if reduction_op in ("smin", "smax") else 4
+    else:
+      raise ValueError(f"Unsupported integer bitwidth: {bitwidth}")
+  elif ir.F16Type.isinstance(element_type):
+    tma_dtype = 5
+  elif ir.F32Type.isinstance(element_type):
+    tma_dtype = 6
+  elif ir.BF16Type.isinstance(element_type):
+    tma_dtype = 7
+  # We treat narrow floats as integers.
+  elif ir.Float8E5M2Type.isinstance(element_type):
+    tma_dtype = 1
+  elif ir.Float8E4M3FNType.isinstance(element_type):
+    tma_dtype = 1
+  elif ir.Float8E8M0FNUType.isinstance(element_type):
+    tma_dtype = 1
+  elif ir.Float4E2M1FNType.isinstance(element_type):
+    tma_dtype = 0
+  else:
+    raise ValueError(f"unsupported TMA dtype {element_type}")
+  return tma_dtype
+
+
class AsyncCopyImplementation(enum.Enum):
  TMA = enum.auto()
  CP_ASYNC = enum.auto()
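For reference, a minimal usage sketch of the new predicate (illustrative only; it assumes the MLIR Python bindings are importable as `ir`, e.g. `from jaxlib.mlir import ir`, and it needs an active MLIR context because the checks compare live MLIR types):

    from jaxlib.mlir import ir

    with ir.Context():
      i32 = ir.IntegerType.get_signless(32)
      assert _is_tma_reduction_op_supported("add", ir.F32Type.get())
      assert _is_tma_reduction_op_supported("smax", i32)  # signed 32-bit max
      assert not _is_tma_reduction_op_supported("min", i32)  # ints need "smin"/"umin"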
@@ -438,7 +530,7 @@ class LaunchContext:
  cluster_size: tuple[int, int, int]
  profiler: OnDeviceProfiler | None = None
  tma_descriptors: dict[
-      tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any],
+      tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any, int],
      ir.Value,
  ] = dataclasses.field(default_factory=dict, init=False)
  is_device_collective: bool = False
@@ -512,10 +604,11 @@ def _get_tma_desc(
      reduction_op: TMAReductionOp | None,
  ):
    gmem_ref = _find_kernel_argument_for_gmem_ref(gmem_ref)
+    tma_dtype = _tma_dma_type(ir.MemRefType(gmem_ref.type).element_type, reduction_op)
    # Using ir.Values in cache keys is a little sketchy, but I think it should
    # be fine. Having it in the key will keep it alive, and if comparison and
    # hashing is by identity then it should work out.
-    tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id)
+    tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id, tma_dtype)
    if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
      i32 = ir.IntegerType.get_signless(32)
      i64 = ir.IntegerType.get_signless(64)
@@ -580,43 +673,6 @@ def init_tma_desc(host_ptr):
        )
        # TODO(apaszke): Better verification (e.g. slice is non-zero)
        # TODO(apaszke): We always know strides statically.
-        if isinstance(ref_ty.element_type, ir.IntegerType):
-          if reduction_op is not None:
-            raise ValueError(
-                f"TMA with reduction_op={reduction_op} is not supported with Integers"
-            )
-          bitwidth = utils.bitwidth_impl(ref_ty.element_type)
-          if bitwidth == 2:
-            tma_dtype = 8
-          elif bitwidth == 4:
-            tma_dtype = 0
-          elif bitwidth == 8:
-            tma_dtype = 1
-          elif bitwidth == 16:
-            tma_dtype = 2
-          elif bitwidth == 32:
-            tma_dtype = 3
-          elif bitwidth == 64:
-            tma_dtype = 4
-          else:
-            raise ValueError(f"Unsupported integer bitwidth: {bitwidth}")
-        elif ir.F16Type.isinstance(ref_ty.element_type):
-          tma_dtype = 5
-        elif ir.F32Type.isinstance(ref_ty.element_type):
-          tma_dtype = 6
-        elif ir.BF16Type.isinstance(ref_ty.element_type):
-          tma_dtype = 7
-        # We treat narrow floats as integers
-        elif ir.Float8E5M2Type.isinstance(ref_ty.element_type):
-          tma_dtype = 1
-        elif ir.Float8E4M3FNType.isinstance(ref_ty.element_type):
-          tma_dtype = 1
-        elif ir.Float8E8M0FNUType.isinstance(ref_ty.element_type):
-          tma_dtype = 1
-        elif ir.Float4E2M1FNType.isinstance(ref_ty.element_type):
-          tma_dtype = 0
-        else:
-          raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}")
        dtype_or_bitwidth = c(tma_dtype, i64)
        args = [
            host_ptr,
@@ -953,16 +1009,10 @@ def async_copy(
    if reduction_op is not None:
      if implementation != AsyncCopyImplementation.TMA:
        raise ValueError("Only the TMA implementation supports reductions")
-      if not any(
-          t.isinstance(element_type)
-          for t in (ir.F32Type, ir.BF16Type, ir.F16Type)
-      ):
-        raise ValueError(
-            "TMA with reduction is only supported with f32, f16 and bf16"
-        )
-      if reduction_op != "add":
+      if not _is_tma_reduction_op_supported(reduction_op, element_type):
        raise ValueError(
-            "TMA with reduction is only supported with add operation"
+            f"Reduction op {reduction_op} not supported by the TMA"
+            f" implementation for element type {element_type}"
        )

    if src_ref_ty.memory_space is None and utils.is_smem_ref(dst_ref_ty):
@@ -1329,7 +1379,7 @@ def async_copy(
          llvm.inline_asm(
              ir.Type.parse("!llvm.void"),
              [predicate, smem_ptr, tma_desc, *rev_dyn_base_indices],
-              f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];",
+              f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{_reduction_op_to_ptx(reduction_op)}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];",
              "b,r,l" + ",r" * rank,
              has_side_effects=True,
          )
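As a worked illustration of the PTX this now emits (assuming `rank == 2`, `reduction_op == "umax"`, and that `idx_operands` expands to `"$3,$4"` for the two dynamic indices), the instruction string becomes roughly:

    @$0 cp.reduce.async.bulk.tensor.2d.global.shared::cta.max.tile.bulk_group [$2,{$3,$4}], [$1];

The signed/unsigned distinction is no longer spelled in the op suffix; it travels through the TMA descriptor's data type, which `_tma_dma_type` above now picks based on the reduction op.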