diff --git a/iris/iris.py b/iris/iris.py index 13e8c51f..0f1073b3 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1237,31 +1237,16 @@ def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async @triton.jit -def __translate(ptr, from_rank, to_rank, heap_bases): +def __translate(ptr, from_rank, to_rank, heap_bases, hint: tl.constexpr = None): from_base = tl.load(heap_bases + from_rank) to_base = tl.load(heap_bases + to_rank) - # convert to int to compute difference ptr_int = tl.cast(ptr, tl.uint64) - # Find the offset from from_rank heap offset = ptr_int - from_base - # Byte cast for byte offset addition to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) - # Find the offset into the to_rank heap translated_ptr_byte = to_base_byte + offset - # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - - # Optimization to vectorize the load/store - # We can't do this in general because we don't know the shape of the tensor or block sizes - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) - - # 0 You can use this if your block sizes are multiples of 32. - # Largest vectorized load instruction is dwordx4 (128-bits) - # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - # translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) - - # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) + if hint is not None: + translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, hint), hint) return translated_ptr @@ -1438,12 +1423,12 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False): return DeviceContext(rank, world_size, heap_bases, device_tracing) @triton.jit - def _translate(self, ptr, from_rank, to_rank): + def _translate(self, ptr, from_rank, to_rank, hint: tl.constexpr = None): """Internal pointer translation between rank address spaces.""" - return __translate(ptr, from_rank, to_rank, self.heap_bases) + return __translate(ptr, from_rank, to_rank, self.heap_bases, hint) @triton.jit - def load(self, pointer, from_rank, mask=None): + def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): """ Loads a value from the specified rank's memory location. @@ -1456,6 +1441,7 @@ def load(self, pointer, from_rank, mask=None): pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint for the translated pointer. Defaults to None. Returns: Block: The loaded value from the target memory location. @@ -1463,12 +1449,12 @@ def load(self, pointer, from_rank, mask=None): Example: >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) """ - translated_ptr = self._translate(pointer, self.rank, from_rank) + translated_ptr = self._translate(pointer, self.rank, from_rank, hint) result = tl.load(translated_ptr, mask=mask) return result @triton.jit - def store(self, pointer, value, to_rank, mask=None): + def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): """ Writes data to the specified rank's memory location. @@ -1489,11 +1475,11 @@ def store(self, pointer, value, to_rank, mask=None): Example: >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) tl.store(translated_ptr, value, mask=mask) @triton.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None): + def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory into current rank's local memory. @@ -1514,12 +1500,12 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None): Example: >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, from_rank=1, mask=mask) """ - translated_from_ptr = self._translate(from_ptr, self.rank, from_rank) + translated_from_ptr = self._translate(from_ptr, self.rank, from_rank, hint) data = tl.load(translated_from_ptr, mask=mask) tl.store(to_ptr, data, mask=mask) @triton.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None): + def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): """ Copies data from current rank's local memory to the specified rank's memory. @@ -1540,12 +1526,12 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None): Example: >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, to_rank=1, mask=mask) """ - translated_to_ptr = self._translate(to_ptr, self.rank, to_rank) + translated_to_ptr = self._translate(to_ptr, self.rank, to_rank, hint) data = tl.load(from_ptr, mask=mask) tl.store(translated_to_ptr, data, mask=mask) @triton.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): + def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constexpr = None): """ Copies data from one rank's memory to another rank's memory. @@ -1585,11 +1571,15 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + if hint is not None: + translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) + translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) + data = tl.load(translated_src, mask=mask) tl.store(translated_dst, data, mask=mask) @triton.jit - def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic add at the specified rank's memory location. @@ -1612,11 +1602,11 @@ def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Example: >>> old_val = ctx.atomic_add(counter, 1, to_rank=1) """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Atomically subtracts data from the specified rank's memory location. @@ -1636,11 +1626,11 @@ def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None): + def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic compare-and-swap at the specified rank's memory location. @@ -1661,11 +1651,11 @@ def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @triton.jit - def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic exchange at the specified rank's memory location. @@ -1685,11 +1675,11 @@ def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic XOR at the specified rank's memory location. @@ -1709,11 +1699,11 @@ def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic AND at the specified rank's memory location. @@ -1733,11 +1723,11 @@ def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic OR at the specified rank's memory location. @@ -1757,11 +1747,11 @@ def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic minimum at the specified rank's memory location. @@ -1781,11 +1771,11 @@ def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic maximum at the specified rank's memory location. @@ -1805,12 +1795,12 @@ def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def load(pointer, to_rank, from_rank, heap_bases, mask=None): +def load(pointer, to_rank, from_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Loads a value from the specified rank's memory location. @@ -1825,6 +1815,7 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): from_rank (int): The rank ID from which to read the data. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: Block: The loaded value from the target memory location. @@ -1838,13 +1829,13 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) >>> return data """ - translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) + translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases, hint) result = tl.load(translated_ptr, mask=mask) return result @triton.jit -def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): +def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Writes data to the specified rank's memory location. @@ -1860,6 +1851,7 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The rank ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1873,12 +1865,12 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): >>> value = 42 >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) tl.store(translated_ptr, value, mask=mask) @triton.jit -def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): +def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory into the destination rank's memory. This function performs the transfer by translating `src_ptr` from the `from_rank`'s address @@ -1895,6 +1887,7 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): cur_rank (int): The rank ID issuing the copy operation. Must be either `from_rank` or `to_rank`. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointers. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1924,12 +1917,16 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + if hint is not None: + translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) + translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) + data = tl.load(translated_src, mask=mask) tl.store(translated_dst, data, mask=mask) @triton.jit -def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -1945,6 +1942,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The current rank ID where the data will be stored. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1956,7 +1954,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): >>> to_rank = 0 >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) """ - translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) + translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases, hint) data = tl.load(translated_from_ptr, mask=mask) @@ -1964,7 +1962,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current @@ -1979,6 +1977,7 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The `to_rank` ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1990,7 +1989,7 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): >>> to_rank = 1 >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) """ - translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) data = tl.load(from_ptr, mask=mask) @@ -1998,7 +1997,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_add( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic add at the specified rank's memory location. @@ -2016,6 +2017,7 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2029,12 +2031,14 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> increment = 5 >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_sub( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Atomically subtracts data from the specified rank's memory location. @@ -2052,6 +2056,7 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The value at the memory location before the atomic subtraction. @@ -2065,12 +2070,12 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> decrement = 3 >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None): +def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None, hint: tl.constexpr = None): """ Atomically compares and exchanges the specified rank's memory location. @@ -2088,6 +2093,7 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The value contained at the memory location before the atomic operation attempt. @@ -2102,12 +2108,14 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop >>> new_val = 42 >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @triton.jit -def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_xchg( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic exchange at the specified rank's memory location. @@ -2125,6 +2133,7 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2138,12 +2147,14 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non >>> new_value = 99 >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_xor( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic xor at the specified rank's memory location. @@ -2161,6 +2172,7 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2174,12 +2186,14 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> mask_val = 0xFF >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_and( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic and at the specified rank's memory location. @@ -2197,6 +2211,7 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2210,12 +2225,12 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> mask_val = 0x0F >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic or at the specified rank's memory location. @@ -2233,6 +2248,7 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2246,12 +2262,14 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, >>> mask_val = 0xF0 >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_min( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic min at the specified rank's memory location. @@ -2269,6 +2287,7 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2282,12 +2301,14 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> new_val = 10 >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_max( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic max at the specified rank's memory location. @@ -2305,6 +2326,7 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2318,7 +2340,7 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> new_val = 100 >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope)