Commit 3757451
[Distributed] one_shot_allreduce_bias_rmsnorm example (#1266)
1 parent a0265ef commit 3757451
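
For orientation, the fused operation implemented by the new example computes, per row of the (N, D) input (this matches reference_one_shot_allreduce_bias_rmsnorm in the diff below; R is the world size, b the bias, w the weight, and the mean is taken over the feature dimension D):

    s = \sum_{r=0}^{R-1} x^{(r)} + b,
    \qquad
    y = \frac{s}{\sqrt{\operatorname{mean}\left(s^{2}\right) + \varepsilon}} \odot w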

File tree

7 files changed: +629 -2 lines changed


examples/__init__.py

Whitespace-only changes.

examples/distributed/__init__.py

Whitespace-only changes.
examples/distributed/one_shot_allreduce_bias_rmsnorm.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
"""
One-Shot All-Reduce + Bias + RMS Norm Fusion Example
=====================================================
This example demonstrates how to implement a fused one-shot all-reduce with bias
addition and RMS normalization using Helion and PyTorch's distributed capabilities.
It includes a Helion kernel demonstrating how to use the symm_mem_sync Triton kernel
for cross-device synchronization and torch.ops.symm_mem.get_remote_tensors for
accessing symmetric memory tensors on peer devices.
"""

from __future__ import annotations

import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from examples.distributed.utils import symm_mem_sync

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl


@helion.jit(
    config=helion.Config(
        block_sizes=[8],
        num_warps=8,
    ),
    static_shapes=True,
)
def one_shot_allreduce_bias_rmsnorm_kernel(
    x: torch.Tensor,
    symm_mem_buffer: torch.Tensor,
    bias: torch.Tensor,
    weight: torch.Tensor,
    signal_pad_ptrs: torch.Tensor,
    EPS: hl.constexpr,
    RANK: hl.constexpr,
    WORLD_SIZE: hl.constexpr,
    GROUP_NAME: hl.constexpr,
) -> torch.Tensor:
    """
    Fused one-shot all-reduce + bias addition + RMS normalization.
    """
    N, D = x.size()
    output = torch.empty_like(x)

    # Get remote buffers from all ranks (views into each rank's symm_mem_buffer)
    buffer_tuple = torch.ops.symm_mem.get_remote_tensors(symm_mem_buffer, GROUP_NAME)

    for tile_n in hl.tile(N):
        # Step 1: Copy input x to our symmetric memory buffer
        symm_mem_buffer[tile_n, :] = x[tile_n, :]

        # Step 2: Sync with hasPreviousMemAccess=True, hasSubsequentMemAccess=True
        # - release fence: ensures our write to symm_mem_buffer is visible to other ranks
        # - acquire fence: ensures we see other ranks' writes to their buffers
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, tile_n.id, RANK, WORLD_SIZE, True, True),
            output_like=None,
        )

        # Step 3: All-reduce + bias: acc = bias + sum(buffer from all ranks)
        # Initialize acc with the right shape by broadcasting bias
        acc = symm_mem_buffer[tile_n, :].to(torch.float32) * 0.0 + bias[None, :].to(
            torch.float32
        )
        for remote_buffer in buffer_tuple:
            acc = acc + remote_buffer[tile_n, :].to(torch.float32)

        # Step 4: RMS Norm: y = acc * rsqrt(mean(acc^2) + eps) * weight
        variance = torch.mean(acc * acc, dim=-1, keepdim=True)
        rstd = torch.rsqrt(variance + EPS)  # type: ignore[unsupported-operation]
        normalized = acc * rstd
        output[tile_n, :] = (normalized * weight[None, :].to(torch.float32)).to(x.dtype)

        # Step 5: Final sync (release only)
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, tile_n.id, RANK, WORLD_SIZE, True, False),
            output_like=None,
        )

    return output


def helion_one_shot_allreduce_bias_rmsnorm(
    x: torch.Tensor,  # Regular input tensor
    bias: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    Wrapper that sets up symmetric memory and calls the Helion kernel.
    """
    group = dist.group.WORLD
    if group is None:
        raise RuntimeError("Distributed group is not initialized")

    N, D = x.shape

    symm_mem_buffer = symm_mem.empty(N, D, dtype=x.dtype, device=x.device)
    symm_mem_hdl = symm_mem.rendezvous(symm_mem_buffer, group.group_name)

    return one_shot_allreduce_bias_rmsnorm_kernel(
        x,
        symm_mem_buffer,
        bias,
        weight,
        symm_mem_hdl.signal_pad_ptrs_dev,
        EPS=eps,
        RANK=symm_mem_hdl.rank,
        WORLD_SIZE=symm_mem_hdl.world_size,
        GROUP_NAME=group.group_name,
    )


def reference_one_shot_allreduce_bias_rmsnorm(
    x: torch.Tensor,
    bias: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    x_reduced = x.clone()
    dist.all_reduce(x_reduced)
    x_with_bias = x_reduced + bias

    # RMS Norm
    variance = x_with_bias.to(torch.float32).pow(2).mean(-1, keepdim=True)
    rstd = torch.rsqrt(variance + eps)
    normalized = x_with_bias.to(torch.float32) * rstd
    return (normalized * weight.to(torch.float32)).to(x.dtype)


def test(N: int, D: int, device: torch.device, dtype: torch.dtype) -> None:
    """Test the Helion implementation against the reference."""
    rank = dist.get_rank()

    torch.manual_seed(42 + rank)
    x = torch.randn(N, D, dtype=dtype, device=device)

    torch.manual_seed(42)
    bias = torch.randn(D, dtype=dtype, device=device)
    weight = torch.randn(D, dtype=dtype, device=device)

    run_example(
        helion_one_shot_allreduce_bias_rmsnorm,
        reference_one_shot_allreduce_bias_rmsnorm,
        (x, bias, weight),
        rtol=1e-4,
        atol=1e-4,
    )


def main() -> None:
    symm_mem.set_backend("NVSHMEM")
    rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    symm_mem.enable_symm_mem_for_group(
        dist.group.WORLD.group_name  # type: ignore[missing-attribute]
    )

    test(N=128, D=4096, device=device, dtype=torch.float32)

    dist.destroy_process_group()


if __name__ == "__main__":
    """
    Run with:
    python -m torch.distributed.run --standalone \
        --nproc-per-node 4 \
        --rdzv-backend c10d --rdzv-endpoint localhost:0 \
        examples/distributed/one_shot_allreduce_bias_rmsnorm.py
    """
    assert DEVICE.type == "cuda", "Requires CUDA device"
    main()
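
For a quick local sanity check of the math, independent of any NCCL/NVSHMEM setup, the sketch below reproduces reference_one_shot_allreduce_bias_rmsnorm on a single process by stacking simulated per-rank inputs and replacing the all-reduce with an explicit sum. The helper name _local_reference and the toy sizes are illustrative, not part of this commit:

import torch

def _local_reference(xs, bias, weight, eps=1e-5):
    # xs: (world_size, N, D) simulated per-rank inputs; sum(dim=0) stands in for dist.all_reduce
    s = xs.sum(dim=0) + bias
    variance = s.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normalized = s.to(torch.float32) * torch.rsqrt(variance + eps)
    return (normalized * weight.to(torch.float32)).to(xs.dtype)

xs = torch.randn(4, 8, 16)    # 4 simulated ranks, N=8, D=16
bias = torch.randn(16)
weight = torch.randn(16)
print(_local_reference(xs, bias, weight).shape)  # torch.Size([8, 16])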

examples/distributed/utils.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
from __future__ import annotations

import triton
import triton.language as tl


@triton.jit
def _get_tid():  # noqa: ANN202
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %tid.x;
        mov.u32 $1, %tid.y;
        mov.u32 $2, %tid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def _get_ntid():  # noqa: ANN202
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %ntid.x;
        mov.u32 $1, %ntid.y;
        mov.u32 $2, %ntid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def _get_flat_tid():  # noqa: ANN202
    tid_x, tid_y, tid_z = _get_tid()
    ntid_x, ntid_y, _ = _get_ntid()
    return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x


@triton.jit
def _get_flat_bid():  # noqa: ANN202
    return (
        tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
        + tl.program_id(1) * tl.num_programs(0)
        + tl.program_id(0)
    )


@triton.jit
def _send_signal(addrs, sem: tl.constexpr) -> None:  # noqa: ANN001
    tl.inline_asm_elementwise(
        f"""
        {{
            .reg .u32 %tmp32_<1>;
            .reg .pred %p<1>;

            send_signal:
                atom.global.{sem}.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                setp.eq.u32 %p0, %tmp32_0, 0;
                @!%p0 bra send_signal;
        }}
        """,
        "=r, l",
        [addrs],
        dtype=addrs.dtype,
        is_pure=False,
        pack=1,
    )


@triton.jit
def _wait_signal(addrs, sem: tl.constexpr) -> None:  # noqa: ANN001
    tl.inline_asm_elementwise(
        f"""
        {{
            .reg .u32 %tmp32_<1>;
            .reg .pred %p<1>;

            wait_signal:
                atom.global.sys.{sem}.cas.b32 %tmp32_0, [$1], 1, 0;
                setp.eq.u32 %p0, %tmp32_0, 1;
                @!%p0 bra wait_signal;
        }}
        """,
        "=r, l",
        [addrs],
        dtype=tl.int32,
        is_pure=False,
        pack=1,
    )


@triton.jit
def symm_mem_sync(
    signal_pad_ptrs,  # noqa: ANN001
    block_id,  # noqa: ANN001
    rank: tl.constexpr,
    world_size: tl.constexpr,
    hasPreviousMemAccess: tl.constexpr = False,  # pyrefly: ignore[bad-function-definition]
    hasSubsequentMemAccess: tl.constexpr = False,  # pyrefly: ignore[bad-function-definition]
) -> None:
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system-level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        symm_mem_sync(..., hasPreviousMemAccess=False, hasSubsequentMemAccess=True)

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=True)

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices:

        symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=False)

    CUDA graph friendliness:

        This barrier operates through atomic operations on a zero-filled signal
        pad, which resets to a zero-filled state after each successful
        synchronization. This design eliminates the need for incrementing a
        flag from the host.
    """
    if block_id is None:
        block_id = _get_flat_bid()
    flat_tid = _get_flat_tid()

    remote_ranks = tl.arange(0, world_size)
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
        tl.pointer_type(tl.uint32)
    )
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank

    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
        tl.pointer_type(tl.uint32)
    )
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks

    if hasPreviousMemAccess:
        tl.debug_barrier()

    if flat_tid < world_size:
        _send_signal(send_addrs, "release" if hasPreviousMemAccess else "relaxed")
        _wait_signal(wait_addrs, "acquire" if hasSubsequentMemAccess else "relaxed")

    if hasSubsequentMemAccess:
        tl.debug_barrier()
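
To make the signal-pad addressing in symm_mem_sync concrete, here is a small worked example in plain Python that mirrors the send_addrs/wait_addrs arithmetic above; the concrete numbers are illustrative only:

# Each rank signals slot (block_id * world_size + rank) on every remote pad
# (CAS 0 -> 1 in _send_signal) and waits on its local pad at slots
# (block_id * world_size + r) for every peer r (CAS 1 -> 0 in _wait_signal,
# which also restores the pad to zero; this is the "CUDA graph friendliness"
# noted in the docstring above).
world_size, rank, block_id = 4, 1, 2
send_slot = block_id * world_size + rank                              # 9
wait_slots = [block_id * world_size + r for r in range(world_size)]   # [8, 9, 10, 11]
print(send_slot, wait_slots)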

scripts/lint_examples_main.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ def main() -> int:
     for filename in sys.argv[1:]:
         if not filename.startswith("examples/") or not filename.endswith(".py"):
             continue
+        if Path(filename).name in ["__init__.py", "utils.py"]:
+            continue
         if not has_main_function(filename):
             print(f"{filename} is missing a main() function.")
             failed = True
