Skip to content

Commit d83ba7b

Browse files
trieuatfacebook-github-bot
authored and committed
jagged_dense_bmm (#1126)
Summary: Add an example of jagged dense bmm. Differential Revision: D84082652
1 parent 2644d0a commit d83ba7b

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

examples/jagged_dense_bmm.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
"""Jagged-Dense Batch Matrix Multiplication using Helion.
2+
3+
This module implements a jagged-dense batch matrix multiplication (BMM)
4+
operation using Helion.
5+
6+
The operation performs batch matrix multiplication between jagged (variable-length)
7+
sequences and dense matrices with optional bias addition. Unlike standard batch
8+
matrix multiplication, the jagged format allows different sequence lengths within
9+
the same batch, making it memory-efficient for variable-length inputs.
10+
11+
Tensor Shapes:
12+
seq_offsets : [B + 1] - Cumulative offsets defining sequence boundaries
13+
where B is the batch size
14+
jagged : [L, D] - Jagged input tensor where L is the total sum of
15+
sequence lengths and D is the embedding dimension
16+
dense : [B, D, K] - Dense weight matrices where K is the output dimension
17+
bias : [B, K] - Optional bias vectors
18+
output : [L, K] - Result with same jagged structure as input
19+
20+
Example:
21+
For a batch of 3 sequences with lengths [2, 1, 3]:
22+
- seq_offsets = [0, 2, 3, 6]
23+
- jagged shape = [6, D] (concatenated sequences)
24+
- dense shape = [3, D, K]
25+
- output shape = [6, K]
26+
27+
Usage:
28+
>>> seq_offsets, jagged, dense, bias = random_input(D=32, K=64, batch_size=16)
29+
>>> output = jagged_dense_bmm(seq_offsets, jagged, dense, bias)
30+
"""
31+
from __future__ import annotations
32+
33+
from typing import Optional, Tuple
34+
35+
import helion
36+
import helion.language as hl
37+
38+
import torch
39+
from helion._testing import run_example
40+
41+
42+
@helion.kernel()
def jagged_dense_bmm(
    seq_offsets: torch.Tensor,
    jagged: torch.Tensor,
    dense: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Multiply every jagged sequence by its batch entry's dense matrix.

    For batch entry ``b``, the rows ``seq_offsets[b]:seq_offsets[b + 1]`` of
    ``jagged`` are multiplied by ``dense[b]`` and, when ``bias`` is given,
    ``bias[b]`` is added to each resulting row.

    Args:
        seq_offsets: Offsets of shape [B + 1]; consecutive entries delimit one
            sequence's rows in ``jagged``.
        jagged: 2-D tensor unpacked as [L, D].
        dense: 3-D tensor unpacked as [B, D, K].
        bias: Optional tensor indexed as [B, K]; skipped when ``None``.

    Returns:
        Tensor of shape [L, K] whose dtype is the promotion of ``jagged.dtype``
        and ``dense.dtype``. The buffer comes from ``torch.empty`` and is only
        written through masked stores, so rows that fall outside every
        sequence are left uninitialized.
    """
    L, D = jagged.shape
    # NOTE: D is rebound here from dense.shape; nothing in this function checks
    # that it matches jagged.shape[1], so callers must pass consistent shapes.
    B, D, K = dense.shape
    dtype = torch.promote_types(jagged.dtype, dense.dtype)
    device = jagged.device

    # Both jagged and output are addressed below through hand-computed flat
    # offsets (row * width + column), so they are viewed as 1-D buffers.
    jagged = jagged.view(-1)  # flattening to [L * D]
    # Allocate output tensor and flatten to 1D
    output = torch.empty((L, K), dtype=dtype, device=device).view(-1)
    for tile_b in hl.tile(B):
        # Per-batch-entry start/end rows and the resulting sequence lengths.
        starts = seq_offsets[tile_b]
        ends = seq_offsets[tile_b.index + 1]
        seq_len = ends - starts
        # The inner length loop is bounded by the longest sequence in this
        # batch tile; shorter sequences are handled by `mask` below.
        max_seq_len = seq_len.amax()

        for tile_len in hl.tile(0, max_seq_len):
            # mask[b, l] is True where position l lies inside sequence b.
            mask = tile_len.index[None, :] < seq_len[:, None]
            # Row index into the [L, ...] tensors for each (b, l) pair.
            jagged_indices = starts[:, None] + tile_len.index[None, :]

            for tile_k in hl.tile(0, K):
                acc = hl.zeros([tile_b, tile_len, tile_k], dtype=dtype, device=device)
                # Reduce over D, accumulating partial products into acc.
                for tile_d in hl.tile(0, D):
                    # Gather from the flattened jagged buffer at row * D + col;
                    # the extra mask excludes positions past a sequence's end
                    # and columns at or beyond D.
                    jagged_data = hl.load(
                        jagged,
                        [jagged_indices[:, :, None] * D + tile_d.index[None, None, :]],
                        extra_mask=mask[:, :, None] & (tile_d.index < D)[None, None, :],
                    )  # [tile_b, tile_len, tile_d]
                    dense_data = dense[tile_b, tile_d, tile_k]

                    acc = acc + torch.matmul(
                        jagged_data, dense_data
                    )  # [tile_b, tile_len, tile_k]

                # Bias is added once per output tile, after the D reduction.
                if bias is not None:
                    bias_data = bias[tile_b, tile_k]  # [tile_b, tile_k]
                    # [tile_b, tile_len, tile_k] + [tile_b, 1, tile_k] -> [tile_b, tile_len, tile_k]
                    acc = acc + bias_data.unsqueeze(1)

                # Scatter into the flattened output at row * K + col, writing
                # only positions that belong to a real sequence element.
                hl.store(
                    output,
                    [jagged_indices[:, :, None] * K + tile_k.index[None, None, :]],
                    acc,
                    extra_mask=mask[:, :, None],
                )
    # Restore the 2-D [L, K] view of the flat output buffer.
    return output.reshape(L, K)
93+
94+
95+
def jagged_dense_bmm_reference(
    seq_offsets: torch.Tensor,
    jagged: torch.Tensor,
    dense: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute jagged-dense BMM one sequence at a time with plain PyTorch.

    For every batch entry ``i`` the rows ``seq_offsets[i]:seq_offsets[i + 1]``
    of ``jagged`` are multiplied by ``dense[i]`` and, when ``bias`` is given,
    ``bias[i]`` is added to every resulting row.

    Args:
        seq_offsets: 1-D tensor with at least ``B + 1`` entries delimiting the
            sequences inside ``jagged``.
        jagged: 2-D tensor of shape [L, D].
        dense: 3-D tensor of shape [B, D, K].
        bias: Optional tensor of shape [B, K].

    Returns:
        Tensor of shape [L, K] in the promoted dtype of ``jagged`` and
        ``dense``, the same output dtype the ``jagged_dense_bmm`` kernel
        uses. Rows not covered by any sequence are left uninitialized.
    """
    L, _ = jagged.shape
    B, _, K = dense.shape

    # torch.matmul rejects operands of different dtypes, so both sides are
    # promoted to a common dtype; this is a no-op when the dtypes already agree.
    out_dtype = torch.promote_types(jagged.dtype, dense.dtype)
    ref_output = torch.empty((L, K), dtype=out_dtype, device=jagged.device)

    # Read all offsets with a single host transfer instead of two `.item()`
    # synchronizations per batch entry.
    bounds = seq_offsets.tolist()

    for i in range(B):
        seq_start, seq_end = bounds[i], bounds[i + 1]
        if seq_start >= seq_end:  # Empty sequence: nothing to compute
            continue

        # [seq_len, D] @ [D, K] -> [seq_len, K]
        seq_data = jagged[seq_start:seq_end].to(out_dtype)
        result = torch.matmul(seq_data, dense[i].to(out_dtype))

        if bias is not None:
            # Broadcast the [K] bias over every row of the sequence.
            result = result + bias[i].unsqueeze(0)

        ref_output[seq_start:seq_end] = result
    return ref_output
125+
126+
127+
def random_input(
    D: int = 4,
    K: int = 5,
    batch_size: int = 3,
    max_seq_len: int = 3,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build random inputs for the jagged-dense BMM example.

    Args:
        D: Embedding (inner) dimension.
        K: Output dimension.
        batch_size: Number of sequences B.
        max_seq_len: Inclusive upper bound on each random sequence length;
            a length of 0 (empty sequence) is possible.
        dtype: Floating dtype of ``jagged``, ``dense`` and ``bias``.
        device: Device on which every tensor is created. Defaults to CUDA,
            which was previously the only (hard-coded) choice.

    Returns:
        ``(seq_offsets, jagged, dense, bias)`` where ``seq_offsets`` is an
        int64 tensor of shape [B + 1] starting at 0, ``jagged`` is [L, D] with
        ``L = seq_offsets[-1]``, ``dense`` is [B, D, K] and ``bias`` is
        [B, K]. The three floating tensors are drawn uniformly from
        [-1, 1) and have ``requires_grad`` set.
    """
    if device is None:
        device = torch.device("cuda")

    def _uniform(*shape: int) -> torch.Tensor:
        # Leaf tensor filled from U(-1, 1) that tracks gradients.
        return (
            torch.empty(shape, dtype=dtype, device=device)
            .uniform_(-1.0, 1.0)
            .requires_grad_()
        )

    lengths = torch.randint(max_seq_len + 1, size=(batch_size,), device=device)
    seq_offsets = torch.zeros((batch_size + 1,), dtype=torch.int64, device=device)
    # Offsets are the running sum of the lengths, with a leading zero.
    seq_offsets[1:] = torch.cumsum(lengths, dim=0)
    jagged_size = int(seq_offsets[-1].item())

    # Creation order (jagged, dense, bias) is kept so seeded runs draw the
    # same random values as before.
    jagged = _uniform(jagged_size, D)
    dense = _uniform(batch_size, D, K)
    bias = _uniform(batch_size, K)
    return seq_offsets, jagged, dense, bias
158+
159+
160+
def main() -> None:
    """Generate random inputs and run the kernel against the reference."""
    example_args = tuple(
        random_input(D=34, K=24, batch_size=23, max_seq_len=37, dtype=torch.float32)
    )
    run_example(jagged_dense_bmm, jagged_dense_bmm_reference, example_args)


if __name__ == "__main__":
    main()

test/test_examples.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,20 @@ def test_jagged_dense_add(self):
701701
)
702702
)
703703

704+
def test_jagged_dense_bmm(self):
    """Check the jagged_dense_bmm example against its reference implementation."""
    example = import_path(EXAMPLES_DIR / "jagged_dense_bmm.py")
    args = tuple(
        example.random_input(
            D=32, K=24, batch_size=16, max_seq_len=32, dtype=torch.float32
        )
    )
    # The reference result is computed first, then handed to check_example.
    expected = example.jagged_dense_bmm_reference(*args)
    journal = check_example("jagged_dense_bmm", args, expected)
    self.assertExpectedJournal(journal)
717+
704718
@skipIfRefEager("Test has skip_accuracy=True and doesn't call assert_close")
705719
def test_moe_matmul_ogs(self):
706720
mod = import_path(EXAMPLES_DIR / "moe_matmul_ogs.py")

0 commit comments

Comments
 (0)