|
| 1 | +""" |
| 2 | +One-Shot All-Reduce + Bias + RMS Norm Fusion Example |
| 3 | +===================================================== |
| 4 | +This example demonstrates how to implement a fused one-shot all-reduce with bias |
| 5 | +addition and RMS normalization using Helion and PyTorch's distributed capabilities. |
| 6 | +It includes a Helion kernel demonstrating how to use symm_mem_sync Triton kernel for |
| 7 | +cross-device synchronization and torch.ops.symm_mem.get_remote_tensors for accessing symmetric |
| 8 | +memory tensors on peer devices. |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import os |
| 14 | + |
| 15 | +import torch |
| 16 | +import torch.distributed as dist |
| 17 | +import torch.distributed._symmetric_memory as symm_mem |
| 18 | + |
| 19 | +from examples.distributed.utils import symm_mem_sync |
| 20 | + |
| 21 | +import helion |
| 22 | +from helion._testing import DEVICE |
| 23 | +from helion._testing import run_example |
| 24 | +import helion.language as hl |
| 25 | + |
| 26 | + |
| 27 | +@helion.jit( |
| 28 | + config=helion.Config( |
| 29 | + block_sizes=[8], |
| 30 | + num_warps=8, |
| 31 | + ), |
| 32 | + static_shapes=True, |
| 33 | +) |
| 34 | +def one_shot_allreduce_bias_rmsnorm_kernel( |
| 35 | + x: torch.Tensor, |
| 36 | + symm_mem_buffer: torch.Tensor, |
| 37 | + bias: torch.Tensor, |
| 38 | + weight: torch.Tensor, |
| 39 | + signal_pad_ptrs: torch.Tensor, |
| 40 | + EPS: hl.constexpr, |
| 41 | + RANK: hl.constexpr, |
| 42 | + WORLD_SIZE: hl.constexpr, |
| 43 | + GROUP_NAME: hl.constexpr, |
| 44 | +) -> torch.Tensor: |
| 45 | + """ |
| 46 | + Fused one-shot all-reduce + bias addition + RMS normalization. |
| 47 | + """ |
| 48 | + N, D = x.size() |
| 49 | + output = torch.empty_like(x) |
| 50 | + |
| 51 | + # Get remote buffers from all ranks (views into each rank's symm_mem_buffer) |
| 52 | + buffer_tuple = torch.ops.symm_mem.get_remote_tensors(symm_mem_buffer, GROUP_NAME) |
| 53 | + |
| 54 | + for tile_n in hl.tile(N): |
| 55 | + # Step 1: Copy input x to our symmetric memory buffer |
| 56 | + symm_mem_buffer[tile_n, :] = x[tile_n, :] |
| 57 | + |
| 58 | + # Step 2: Sync with hasPreviousMemAccess=True hasSubsequentMemAccess=True |
| 59 | + # - release fence: ensures our write to symm_mem_buffer is visible to other ranks |
| 60 | + # - acquire fence: ensures we see other ranks' writes to their buffers |
| 61 | + hl.triton_kernel( |
| 62 | + symm_mem_sync, |
| 63 | + args=(signal_pad_ptrs, tile_n.id, RANK, WORLD_SIZE, True, True), |
| 64 | + output_like=None, |
| 65 | + ) |
| 66 | + |
| 67 | + # Step 3: All-reduce + bias: acc = bias + sum(buffer from all ranks) |
| 68 | + # Initialize acc with the right shape by broadcasting bias |
| 69 | + acc = symm_mem_buffer[tile_n, :].to(torch.float32) * 0.0 + bias[None, :].to( |
| 70 | + torch.float32 |
| 71 | + ) |
| 72 | + for remote_buffer in buffer_tuple: |
| 73 | + acc = acc + remote_buffer[tile_n, :].to(torch.float32) |
| 74 | + |
| 75 | + # Step 4: RMS Norm: y = acc * rsqrt(mean(acc^2) + eps) * weight |
| 76 | + variance = torch.mean(acc * acc, dim=-1, keepdim=True) |
| 77 | + rstd = torch.rsqrt(variance + EPS) # type: ignore[unsupported-operation] |
| 78 | + normalized = acc * rstd |
| 79 | + output[tile_n, :] = (normalized * weight[None, :].to(torch.float32)).to(x.dtype) |
| 80 | + |
| 81 | + # Step 5: Final sync (release only) |
| 82 | + hl.triton_kernel( |
| 83 | + symm_mem_sync, |
| 84 | + args=(signal_pad_ptrs, tile_n.id, RANK, WORLD_SIZE, True, False), |
| 85 | + output_like=None, |
| 86 | + ) |
| 87 | + |
| 88 | + return output |
| 89 | + |
| 90 | + |
| 91 | +def helion_one_shot_allreduce_bias_rmsnorm( |
| 92 | + x: torch.Tensor, # Regular input tensor |
| 93 | + bias: torch.Tensor, |
| 94 | + weight: torch.Tensor, |
| 95 | + eps: float = 1e-5, |
| 96 | +) -> torch.Tensor: |
| 97 | + """ |
| 98 | + Wrapper that sets up symmetric memory and calls the Helion kernel. |
| 99 | + """ |
| 100 | + group = dist.group.WORLD |
| 101 | + if group is None: |
| 102 | + raise RuntimeError("Distributed group is not initialized") |
| 103 | + |
| 104 | + N, D = x.shape |
| 105 | + |
| 106 | + symm_mem_buffer = symm_mem.empty(N, D, dtype=x.dtype, device=x.device) |
| 107 | + symm_mem_hdl = symm_mem.rendezvous(symm_mem_buffer, group.group_name) |
| 108 | + |
| 109 | + return one_shot_allreduce_bias_rmsnorm_kernel( |
| 110 | + x, |
| 111 | + symm_mem_buffer, |
| 112 | + bias, |
| 113 | + weight, |
| 114 | + symm_mem_hdl.signal_pad_ptrs_dev, |
| 115 | + EPS=eps, |
| 116 | + RANK=symm_mem_hdl.rank, |
| 117 | + WORLD_SIZE=symm_mem_hdl.world_size, |
| 118 | + GROUP_NAME=group.group_name, |
| 119 | + ) |
| 120 | + |
| 121 | + |
| 122 | +def reference_one_shot_allreduce_bias_rmsnorm( |
| 123 | + x: torch.Tensor, |
| 124 | + bias: torch.Tensor, |
| 125 | + weight: torch.Tensor, |
| 126 | + eps: float = 1e-5, |
| 127 | +) -> torch.Tensor: |
| 128 | + x_reduced = x.clone() |
| 129 | + dist.all_reduce(x_reduced) |
| 130 | + x_with_bias = x_reduced + bias |
| 131 | + |
| 132 | + # RMS Norm |
| 133 | + variance = x_with_bias.to(torch.float32).pow(2).mean(-1, keepdim=True) |
| 134 | + rstd = torch.rsqrt(variance + eps) |
| 135 | + normalized = x_with_bias.to(torch.float32) * rstd |
| 136 | + return (normalized * weight.to(torch.float32)).to(x.dtype) |
| 137 | + |
| 138 | + |
| 139 | +def test(N: int, D: int, device: torch.device, dtype: torch.dtype) -> None: |
| 140 | + """Test the Helion implementation against the reference.""" |
| 141 | + rank = dist.get_rank() |
| 142 | + |
| 143 | + torch.manual_seed(42 + rank) |
| 144 | + x = torch.randn(N, D, dtype=dtype, device=device) |
| 145 | + |
| 146 | + torch.manual_seed(42) |
| 147 | + bias = torch.randn(D, dtype=dtype, device=device) |
| 148 | + weight = torch.randn(D, dtype=dtype, device=device) |
| 149 | + |
| 150 | + run_example( |
| 151 | + helion_one_shot_allreduce_bias_rmsnorm, |
| 152 | + reference_one_shot_allreduce_bias_rmsnorm, |
| 153 | + (x, bias, weight), |
| 154 | + rtol=1e-4, |
| 155 | + atol=1e-4, |
| 156 | + ) |
| 157 | + |
| 158 | + |
| 159 | +def main() -> None: |
| 160 | + symm_mem.set_backend("NVSHMEM") |
| 161 | + rank = int(os.environ["LOCAL_RANK"]) |
| 162 | + torch.manual_seed(42 + rank) |
| 163 | + device = torch.device(f"cuda:{rank}") |
| 164 | + torch.cuda.set_device(device) |
| 165 | + dist.init_process_group("nccl") |
| 166 | + symm_mem.enable_symm_mem_for_group( |
| 167 | + dist.group.WORLD.group_name # type: ignore[missing-attribute] |
| 168 | + ) |
| 169 | + |
| 170 | + test(N=128, D=4096, device=device, dtype=torch.float32) |
| 171 | + |
| 172 | + dist.destroy_process_group() |
| 173 | + |
| 174 | + |
| 175 | +if __name__ == "__main__": |
| 176 | + """ |
| 177 | + Run with: |
| 178 | + python -m torch.distributed.run --standalone \ |
| 179 | + --nproc-per-node 4 \ |
| 180 | + --rdzv-backend c10d --rdzv-endpoint localhost:0 \ |
| 181 | + examples/distributed/one_shot_allreduce_bias_rmsnorm.py |
| 182 | + """ |
| 183 | + assert DEVICE.type == "cuda", "Requires CUDA device" |
| 184 | + main() |
0 commit comments