中文 | English
通过一个完整的 end-to-end 示例,演示如何使用 LLM agent 开发和优化一个 Triton kernel。
Target: Fused LayerNorm + GELU kernel (common in Transformer models)
# Reference implementation (PyTorch)
def fused_layernorm_gelu(x, weight, bias, eps=1e-5):
    """Eager PyTorch reference: LayerNorm followed by tanh-approximate GELU.

    Args:
        x: input of shape [M, N]; statistics are taken over the last dim.
        weight: per-feature scale of shape [N].
        bias: per-feature shift of shape [N].
        eps: added to the variance before the square root.

    Returns:
        y = GELU(LayerNorm(x, weight, bias)) with shape [M, N].
    """
    # LayerNorm uses the biased (population) variance, hence unbiased=False.
    mu = x.mean(dim=-1, keepdim=True)
    sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
    normed = (x - mu) / torch.sqrt(sigma_sq + eps)
    affine = normed * weight + bias

    # tanh-approximate GELU; 0.7978845608 ~= sqrt(2 / pi).
    inner = 0.7978845608 * (affine + 0.044715 * affine ** 3)
    return affine * 0.5 * (1.0 + torch.tanh(inner))


# Task: Write a Triton kernel that fuses LayerNorm + GELU activation.
Specification:
- Input: x [M, N] float16, weight [N] float32, bias [N] float32
- Output: y [M, N] float16
- Operation: y = GELU(LayerNorm(x, weight, bias))
- LayerNorm along the last dimension (N)
- Use approximate GELU (tanh approximation)
Requirements:
- Each program instance processes one row (M dimension)
- Handle arbitrary N (not necessarily power of 2)
- Accumulate in float32 for numerical stability
- Include proper masking for non-aligned N
Target: NVIDIA H100, N typically 4096-12288
import torch
import triton
import triton.language as tl


@triton.jit
def fused_layernorm_gelu_kernel(
    x_ptr, weight_ptr, bias_ptr, y_ptr,
    M, N, eps,
    stride_xm, stride_xn,
    BLOCK_N: tl.constexpr,
):
    """v1: fused LayerNorm + tanh-approximate GELU, one program per row.

    Three passes over the row: (1) mean, (2) variance, (3) normalize,
    apply weight/bias, apply GELU and store as float16. Statistics are
    accumulated in float32.

    Notes:
        - ``y`` is addressed with ``x``'s strides, so the output must share
          the input's layout (e.g. ``torch.empty_like(x)``).
        - ``weight`` and ``bias`` are indexed with unit stride, so they are
          assumed to be contiguous.
        - ``M`` is not read; the launch grid supplies one program per row.
    """
    # program_id is int32; promote before multiplying by the row stride so the
    # offset does not overflow for tensors with more than 2**31 elements.
    row_idx = tl.program_id(0).to(tl.int64)
    row_start = x_ptr + row_idx * stride_xm
    out_start = y_ptr + row_idx * stride_xm

    # Pass 1: compute mean (masked lanes load 0.0, which does not affect the sum).
    _mean = tl.zeros([BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        _mean += x
    mean = tl.sum(_mean, axis=0) / N

    # Pass 2: compute population variance (divide by N).
    _var = tl.zeros([BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        # Masked lanes hold 0.0, so x - mean would add mean**2 per padded lane;
        # zero them so a non-multiple-of-BLOCK_N row gets the right variance.
        diff = tl.where(mask, x - mean, 0.0)
        _var += diff * diff
    var = tl.sum(_var, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + eps)

    # Pass 3: normalize + affine + GELU, store as float16.
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + cols, mask=mask, other=1.0).to(tl.float32)
        b = tl.load(bias_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        # LayerNorm
        x_norm = (x - mean) * rstd
        x_ln = x_norm * w + b
        # Approximate GELU. 0.5 * (1 + tanh(z)) == sigmoid(2 * z); tl.sigmoid is
        # core triton.language, while tl.math.tanh is not exposed by every
        # Triton release. 0.7978845608 ~= sqrt(2 / pi).
        inner = 0.7978845608 * (x_ln + 0.044715 * x_ln * x_ln * x_ln)
        y = x_ln * tl.sigmoid(2.0 * inner)
        tl.store(out_start + cols * stride_xn, y.to(tl.float16), mask=mask)
def test_correctness():
M, N = 1024, 4096
x = torch.randn(M, N, dtype=torch.float16, device="cuda")
weight = torch.randn(N, dtype=torch.float32, device="cuda")
bias = torch.randn(N, dtype=torch.float32, device="cuda")
# Reference
ref = fused_layernorm_gelu(x.float(), weight, bias).half()
# Generated kernel
y = torch.empty_like(x)
BLOCK_N = triton.next_power_of_2(min(N, 4096))
grid = (M,)
fused_layernorm_gelu_kernel[grid](
x, weight, bias, y, M, N, 1e-5,
x.stride(0), x.stride(1),
BLOCK_N=BLOCK_N,
)
# Check
max_err = (ref - y).abs().max().item()
print(f"Max absolute error: {max_err:.6f}")
assert torch.allclose(ref, y, rtol=1e-2, atol=1e-2), f"Failed! Max error: {max_err}"
print("Correctness: PASS")
test_correctness()ref_ms = triton.testing.do_bench(lambda: fused_layernorm_gelu(x.float(), weight, bias))
gen_ms = triton.testing.do_bench(lambda: fused_layernorm_gelu_kernel[grid](
x, weight, bias, y, M, N, 1e-5, x.stride(0), x.stride(1), BLOCK_N=BLOCK_N))
print(f"Reference: {ref_ms:.3f} ms")
print(f"Generated: {gen_ms:.3f} ms")
print(f"Speedup: {ref_ms/gen_ms:.2f}x")典型结果: 初始生成的 kernel 可能比参考实现慢 (3-pass over data) 或只有 ~1.2x speedup.
Feed the profiling results back to the LLM:
Current kernel performance:
- 3 passes over the data (mean, variance, normalize+gelu)
- Memory-bound (N=4096: each row fits in L2, but is read 3 times)
- Low compute intensity
Optimization request:
Can you reduce this to 2 passes or a single pass?
For N <= BLOCK_N, all data fits in registers.
@triton.jit
def fused_layernorm_gelu_v2(
x_ptr, weight_ptr, bias_ptr, y_ptr,
M, N, eps,
stride_xm, stride_xn,
BLOCK_N: tl.constexpr,
):
row_idx = tl.program_id(0)
row_start = x_ptr + row_idx * stride_xm
out_start = y_ptr + row_idx * stride_xm
# Single pass: compute mean and variance using Welford's algorithm
_mean = tl.zeros([BLOCK_N], dtype=tl.float32)
_m2 = tl.zeros([BLOCK_N], dtype=tl.float32)
_count = tl.zeros([BLOCK_N], dtype=tl.float32)
for off in range(0, N, BLOCK_N):
cols = off + tl.arange(0, BLOCK_N)
mask = cols < N
x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
_count += mask.to(tl.float32)
delta = x - _mean
_mean += tl.where(mask, delta / tl.maximum(_count, 1.0), 0.0)
delta2 = x - _mean
_m2 += tl.where(mask, delta * delta2, 0.0)
mean = tl.sum(_mean, axis=0) / N
var = tl.sum(_m2, axis=0) / N
rstd = 1.0 / tl.sqrt(var + eps)
# Pass 2: Normalize + GELU (same as before)
for off in range(0, N, BLOCK_N):
cols = off + tl.arange(0, BLOCK_N)
mask = cols < N
x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
w = tl.load(weight_ptr + cols, mask=mask, other=1.0).to(tl.float32)
b = tl.load(bias_ptr + cols, mask=mask, other=0.0).to(tl.float32)
x_norm = (x - mean) * rstd
x_ln = x_norm * w + b
y = 0.5 * x_ln * (1.0 + tl.math.tanh(0.7978845608 * (x_ln + 0.044715 * x_ln * x_ln * x_ln)))
tl.store(out_start + cols * stride_xn, y.to(tl.float16), mask=mask)v1 (3-pass): 0.45 ms
v2 (2-pass): 0.32 ms (1.4x faster)
PyTorch ref: 0.55 ms
Speedup vs PyTorch: 1.72x ✓
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 1024}, num_warps=4),
triton.Config({"BLOCK_N": 2048}, num_warps=8),
triton.Config({"BLOCK_N": 4096}, num_warps=8),
],
key=["N"],
)
@triton.jit
def fused_layernorm_gelu_final(...):
...- LLM 可以生成功能正确的初始 kernel,但通常不是最优的
- Profiling feedback 是关键 — 让 LLM 理解具体的性能瓶颈
- 迭代优化比 one-shot 生成更有效
- 数值验证必须自动化 — 每次修改后都要验证正确性
- 最终仍需要人工审核 — 确认优化策略的合理性