Commit 83a8c2d

trieuat authored and facebook-github-bot committed
jagged_dense_bmm (#1126)
Summary: Add an example of jagged dense bmm.

Differential Revision: D84082652
1 parent 2644d0a commit 83a8c2d
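For context on the data layout: seq_offsets encodes where each batch element's rows live inside the packed jagged tensor. A small hand-worked illustration (the concrete values below are chosen for exposition and are not taken from the commit):

    import torch

    # Three sequences of lengths 2, 0, and 3 packed into a single [L, D] tensor.
    seq_offsets = torch.tensor([0, 2, 2, 5])  # shape [B + 1] with B = 3
    lengths = seq_offsets[1:] - seq_offsets[:-1]  # tensor([2, 0, 3])
    L = int(seq_offsets[-1])  # L = 5 total rows across the batch
    # Batch element i occupies rows jagged[seq_offsets[i]:seq_offsets[i + 1]],
    # which is exactly the slicing the reference implementation performs.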

File tree

2 files changed (+162 −0 lines)


examples/jagged_dense_bmm.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
from __future__ import annotations

from typing import Optional, Tuple

import helion
import helion.language as hl

import torch
from helion._testing import run_example

"""
---jagged_dense_bmm---
seq_offsets : [B + 1]    # B is batch size
jagged      : [L, D]     # L is sum of sequence lengths, D is embedding dimension
dense       : [B, D, K]  # K is output dimension
bias        : [B, K]     # optional bias
"""


@helion.kernel()
def jagged_dense_bmm(
    seq_offsets: torch.Tensor,
    jagged: torch.Tensor,
    dense: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    L, D = jagged.shape
    B, D, K = dense.shape
    dtype = torch.promote_types(jagged.dtype, dense.dtype)
    device = jagged.device

    jagged = jagged.view(-1)  # flatten to [L * D]
    # Allocate output tensor and flatten to 1D
    output = torch.empty((L, K), dtype=dtype, device=device).view(-1)
    for tile_b in hl.tile(B):
        starts = seq_offsets[tile_b]
        ends = seq_offsets[tile_b.index + 1]
        seq_len = ends - starts
        max_seq_len = seq_len.amax()

        for tile_len in hl.tile(0, max_seq_len):
            # Mask out positions past each sequence's actual length
            mask = tile_len.index[None, :] < seq_len[:, None]
            jagged_indices = starts[:, None] + tile_len.index[None, :]

            for tile_k in hl.tile(0, K):
                acc = hl.zeros([tile_b, tile_len, tile_k], dtype=dtype, device=device)
                for tile_d in hl.tile(0, D):
                    # Gather from the flattened jagged tensor: flat index = row * D + col
                    jagged_data = hl.load(
                        jagged,
                        [jagged_indices[:, :, None] * D + tile_d.index[None, None, :]],
                        extra_mask=mask[:, :, None] & (tile_d.index < D)[None, None, :],
                    )  # [tile_b, tile_len, tile_d]
                    dense_data = dense[tile_b, tile_d, tile_k]

                    acc = acc + torch.matmul(
                        jagged_data, dense_data
                    )  # [tile_b, tile_len, tile_k]

                if bias is not None:
                    bias_data = bias[tile_b, tile_k]  # [tile_b, tile_k]
                    # [tile_b, tile_len, tile_k] + [tile_b, 1, tile_k] -> [tile_b, tile_len, tile_k]
                    acc = acc + bias_data.unsqueeze(1)

                hl.store(
                    output,
                    [jagged_indices[:, :, None] * K + tile_k.index[None, None, :]],
                    acc,
                    extra_mask=mask[:, :, None],
                )
    return output.reshape(L, K)


def jagged_dense_bmm_reference(
    seq_offsets: torch.Tensor,
    jagged: torch.Tensor,
    dense: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    L, D = jagged.shape
    B, _, K = dense.shape

    # Allocate output tensor
    ref_output = torch.empty((L, K), dtype=jagged.dtype, device=jagged.device)

    # Process each example in the batch
    for i in range(B):
        seq_start = seq_offsets[i].item()
        seq_end = seq_offsets[i + 1].item()

        if seq_start < seq_end:  # Non-empty sequence
            seq_data = jagged[seq_start:seq_end]  # [seq_len, D]

            # Matrix multiplication: [seq_len, D] @ [D, K] -> [seq_len, K]
            result = torch.matmul(seq_data, dense[i])

            # Add bias if provided
            if bias is not None:
                result = result + bias[i].unsqueeze(0)

            # Store result
            ref_output[seq_start:seq_end] = result
    return ref_output


def random_input(
    D: int = 4,
    K: int = 5,
    batch_size: int = 3,
    max_seq_len: int = 3,
    dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    lengths = torch.randint(
        max_seq_len + 1, size=(batch_size,), device=torch.device("cuda")
    )
    seq_offsets = torch.zeros(
        (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda")
    )
    seq_offsets[1:] = torch.cumsum(lengths, dim=0)
    jagged_size = int(seq_offsets[-1].item())
    jagged = (
        torch.empty((jagged_size, D), dtype=dtype, device=torch.device("cuda"))
        .uniform_(-1.0, 1.0)
        .requires_grad_()
    )
    dense = (
        torch.empty((batch_size, D, K), dtype=dtype, device=torch.device("cuda"))
        .uniform_(-1.0, 1.0)
        .requires_grad_()
    )
    bias = (
        torch.empty((batch_size, K), dtype=dtype, device=torch.device("cuda"))
        .uniform_(-1.0, 1.0)
        .requires_grad_()
    )
    return seq_offsets, jagged, dense, bias


def main() -> None:
    seq_offsets, jagged, dense, bias = random_input(
        D=34, K=24, batch_size=23, max_seq_len=37, dtype=torch.float32
    )
    run_example(
        jagged_dense_bmm, jagged_dense_bmm_reference, (seq_offsets, jagged, dense, bias)
    )


if __name__ == "__main__":
    main()
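A note on the addressing scheme: the kernel flattens jagged and output to 1-D and computes row-major offsets by hand (row * D + col for loads, row * K + col for stores), with extra_mask suppressing lanes that fall past a sequence's end. A minimal sketch of that index arithmetic in plain PyTorch, using illustrative names that are not part of the example file:

    import torch

    L, D = 6, 4
    mat = torch.arange(L * D).reshape(L, D)
    flat = mat.view(-1)  # the same flattening the kernel applies to jagged

    rows = torch.tensor([1, 3])  # plays the role of jagged_indices
    cols = torch.arange(D)       # plays the role of tile_d.index
    # flat[row * D + col] recovers mat[row, col] for every (row, col) pair
    gathered = flat[rows[:, None] * D + cols[None, :]]
    assert torch.equal(gathered, mat[rows])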

test/test_examples.py

Lines changed: 14 additions & 0 deletions
@@ -701,6 +701,20 @@ def test_jagged_dense_add(self):
             )
         )
 
+    def test_jagged_dense_bmm(self):
+        mod = import_path(EXAMPLES_DIR / "jagged_dense_bmm.py")
+        seq_offsets, jagged, dense, bias = mod.random_input(
+            D=32, K=24, batch_size=16, max_seq_len=32, dtype=torch.float32
+        )
+        args = (seq_offsets, jagged, dense, bias)
+        self.assertExpectedJournal(
+            check_example(
+                "jagged_dense_bmm",
+                args,
+                mod.jagged_dense_bmm_reference(*args),
+            )
+        )
+
     @skipIfRefEager("Test has skip_accuracy=True and doesn't call assert_close")
     def test_moe_matmul_ogs(self):
         mod = import_path(EXAMPLES_DIR / "moe_matmul_ogs.py")
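To run the new test in isolation, assuming the repository's usual pytest setup (the invocation below is a sketch, not taken from the commit):

    python -m pytest test/test_examples.py -k test_jagged_dense_bmm

The example can also be run directly with python examples/jagged_dense_bmm.py, which calls main() and compares the kernel against jagged_dense_bmm_reference via run_example.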
