Skip to content

Latest commit

 

History

History
260 lines (204 loc) · 7.66 KB

File metadata and controls

260 lines (204 loc) · 7.66 KB

中文 | English

Practical Tutorial: LLM Agent for Triton Kernel Development

Goal

通过一个完整的 end-to-end 示例,演示如何使用 LLM agent 开发和优化一个 Triton kernel。

Target: Fused LayerNorm + GELU kernel (common in Transformer models)

Step 1: Define the Task

# Reference implementation (PyTorch)
def fused_layernorm_gelu(x, weight, bias, eps=1e-5):
    """Eager PyTorch reference for y = GELU(LayerNorm(x, weight, bias)).

    Args:
        x: input of shape [M, N]; normalization is taken over the last dim.
        weight: per-column scale of shape [N].
        bias: per-column shift of shape [N].
        eps: added to the (biased) variance before the square root.

    Returns:
        Tensor y with the same shape as ``x``.
    """
    # LayerNorm with the biased (population) variance.
    mu = x.mean(dim=-1, keepdim=True)
    sigma_sq = x.var(dim=-1, unbiased=False, keepdim=True)
    normalized = (x - mu) / torch.sqrt(sigma_sq + eps)
    affine = normalized * weight + bias

    # Tanh approximation of GELU; 0.7978845608 ~= sqrt(2 / pi).
    inner = 0.7978845608 * (affine + 0.044715 * affine ** 3)
    return affine * 0.5 * (1.0 + torch.tanh(inner))

Step 2: Prompt the LLM

Task: Write a Triton kernel that fuses LayerNorm + GELU activation.

Specification:
- Input: x [M, N] float16, weight [N] float32, bias [N] float32
- Output: y [M, N] float16
- Operation: y = GELU(LayerNorm(x, weight, bias))
- LayerNorm along the last dimension (N)
- Use approximate GELU (tanh approximation)

Requirements:
- Each program instance processes one row (M dimension)
- Handle arbitrary N (not necessarily power of 2)
- Accumulate in float32 for numerical stability
- Include proper masking for non-aligned N

Target: NVIDIA H100, N typically 4096-12288

Step 3: LLM Generates Initial Kernel

import triton
import triton.language as tl

@triton.jit
def fused_layernorm_gelu_kernel(
    x_ptr, weight_ptr, bias_ptr, y_ptr,
    M, N, eps,
    stride_xm, stride_xn,
    BLOCK_N: tl.constexpr,
):
    """Fused LayerNorm + tanh-approximate GELU, one row per program (3 passes).

    Launch with grid ``(M,)``: program ``i`` processes row ``i``. The row is
    walked in chunks of ``BLOCK_N`` columns masked with ``cols < N``, so ``N``
    does not have to be a multiple of ``BLOCK_N``. Arithmetic is done in
    float32 and the result is stored as float16.

    Args:
        x_ptr: input rows, addressed with ``stride_xm`` / ``stride_xn``.
        weight_ptr, bias_ptr: per-column affine parameters, indexed with unit
            stride over ``cols < N``.
        y_ptr: output rows. NOTE: ``y`` is addressed with *x's* strides, so
            the output buffer must have the same layout as ``x``
            (e.g. ``torch.empty_like(x)``).
        M: number of rows (not read inside the kernel; the grid encodes it).
        N: number of columns normalized over.
        eps: added to the variance before the square root.
        BLOCK_N: chunk width (compile-time constant).
    """
    row_idx = tl.program_id(0)

    # Base pointers for this row (y reuses x's row stride, see docstring).
    row_start = x_ptr + row_idx * stride_xm
    out_start = y_ptr + row_idx * stride_xm

    # Pass 1: mean. Each lane accumulates every BLOCK_N-th element; lanes are
    # reduced once after the loop.
    _mean = tl.zeros([BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        _mean += x
    mean = tl.sum(_mean, axis=0) / N

    # Pass 2: biased variance around the exact mean.
    _var = tl.zeros([BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        # Masked lanes loaded 0.0; zero their contribution explicitly.
        diff = tl.where(mask, x - mean, 0.0)
        _var += diff * diff
    var = tl.sum(_var, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + eps)

    # Pass 3: normalize, apply affine, GELU, store.
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + cols, mask=mask, other=1.0).to(tl.float32)
        b = tl.load(bias_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # LayerNorm
        x_norm = (x - mean) * rstd
        x_ln = x_norm * w + b

        # Tanh-approximate GELU via the exact identity
        #   0.5 * z * (1 + tanh(u)) == z * sigmoid(2 * u),
        #   u = sqrt(2/pi) * (z + 0.044715 * z^3).
        # This avoids tl.math.tanh, which is no longer in tl.math on Triton 3.x.
        # 1.5957691216 == 2 * 0.7978845608 == 2 * sqrt(2 / pi).
        inner = 1.5957691216 * (x_ln + 0.044715 * x_ln * x_ln * x_ln)
        y = x_ln * tl.sigmoid(inner)

        tl.store(out_start + cols * stride_xn, y.to(tl.float16), mask=mask)

Step 4: Verify Correctness

import torch


def test_correctness(M=1024, N=4096, kernel=None):
    """Check a Triton LayerNorm+GELU kernel against ``fused_layernorm_gelu``.

    Args:
        M, N: problem size. An ``N`` that is not a multiple of ``BLOCK_N``
            exercises masking. ``N > 4096`` exercises the multi-chunk loop,
            because ``BLOCK_N`` is capped at 4096 below.
        kernel: ``@triton.jit`` kernel with the same signature as
            ``fused_layernorm_gelu_kernel``, which is the default.

    Raises:
        AssertionError: if the kernel output differs from the reference
            beyond ``rtol=1e-2, atol=1e-2``.
    """
    if kernel is None:
        kernel = fused_layernorm_gelu_kernel

    x = torch.randn(M, N, dtype=torch.float16, device="cuda")
    weight = torch.randn(N, dtype=torch.float32, device="cuda")
    bias = torch.randn(N, dtype=torch.float32, device="cuda")

    # Reference computed in float32, then rounded to the kernel's output dtype.
    ref = fused_layernorm_gelu(x.float(), weight, bias).half()

    # Kernel under test; y must share x's layout (the kernels reuse x's strides).
    y = torch.empty_like(x)
    BLOCK_N = triton.next_power_of_2(min(N, 4096))
    grid = (M,)
    kernel[grid](
        x, weight, bias, y, M, N, 1e-5,
        x.stride(0), x.stride(1),
        BLOCK_N=BLOCK_N,
    )

    # Check
    max_err = (ref - y).abs().max().item()
    print(f"Max absolute error: {max_err:.6f}")
    assert torch.allclose(ref, y, rtol=1e-2, atol=1e-2), f"Failed! Max error: {max_err}"
    print("Correctness: PASS")


# Cover an aligned single-chunk row, a non-aligned N (masked tail chunk),
# and a row that needs several BLOCK_N chunks.
for m, n in [(1024, 4096), (1024, 5000), (256, 12288)]:
    test_correctness(m, n)

Step 5: Profile and Identify Bottleneck

# Benchmark inputs. The tensors used in test_correctness() are locals of that
# function, so they are created again here at module level.
M, N = 1024, 4096
x = torch.randn(M, N, dtype=torch.float16, device="cuda")
weight = torch.randn(N, dtype=torch.float32, device="cuda")
bias = torch.randn(N, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
BLOCK_N = triton.next_power_of_2(min(N, 4096))
grid = (M,)

# The reference timing includes the fp16 -> fp32 cast of x, since the
# kernel also starts from the fp16 input.
ref_ms = triton.testing.do_bench(lambda: fused_layernorm_gelu(x.float(), weight, bias))
gen_ms = triton.testing.do_bench(lambda: fused_layernorm_gelu_kernel[grid](
    x, weight, bias, y, M, N, 1e-5, x.stride(0), x.stride(1), BLOCK_N=BLOCK_N))

print(f"Reference: {ref_ms:.3f} ms")
print(f"Generated: {gen_ms:.3f} ms")
print(f"Speedup: {ref_ms/gen_ms:.2f}x")

典型结果: 由于要对数据做 3 次遍历 (3-pass over data),初始生成的 kernel 可能比参考实现更慢,或者只有 ~1.2x 的 speedup。

Step 6: LLM Optimizes (Iteration 1)

Feed profiling results back to LLM:

Current kernel performance:
- 3 passes over the data (mean, variance, normalize+gelu)
- Memory-bound (N=4096, each row fits in L2 but 3x read)
- Low compute intensity

Optimization request:
Can you reduce to 2 passes or single pass?
For N <= BLOCK_N, all data fits in registers.

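LLM 优化建议: 减少到 2-pass — 用 Welford's online algorithm 在一次遍历中同时计算 mean 和 variance。注意: kernel 中每个 lane 各自维护一组 (count, mean, M2)。循环结束后必须用并行合并公式 (Chan et al.) 归约: mean = Σ countᵢ·meanᵢ / N,M2 = Σ M2ᵢ + Σ countᵢ·(meanᵢ − mean)²。不能直接把各 lane 的 mean 或 M2 相加后除以 N。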

@triton.jit
def fused_layernorm_gelu_v2(
    x_ptr, weight_ptr, bias_ptr, y_ptr,
    M, N, eps,
    stride_xm, stride_xn,
    BLOCK_N: tl.constexpr,
):
    """Fused LayerNorm + tanh-approximate GELU in two passes over each row.

    Pass 1 computes mean and variance in a single sweep. Each of the
    ``BLOCK_N`` lanes runs Welford's update over the columns it sees
    (lane ``i`` sees columns ``i, i + BLOCK_N, ...``). The per-lane
    (count, mean, M2) triples are then merged with the parallel
    (Chan et al.) formula. Pass 2 normalizes, applies the affine transform
    and GELU, and stores float16.

    Launch with grid ``(M,)``. Same arguments as
    ``fused_layernorm_gelu_kernel``. NOTE: ``y`` is addressed with x's
    strides, so it must have the same layout as ``x``. ``M`` is not read
    inside the kernel.
    """
    row_idx = tl.program_id(0)
    row_start = x_ptr + row_idx * stride_xm
    out_start = y_ptr + row_idx * stride_xm  # assumes y shares x's layout

    # Pass 1: per-lane Welford statistics.
    _mean = tl.zeros([BLOCK_N], dtype=tl.float32)
    _m2 = tl.zeros([BLOCK_N], dtype=tl.float32)
    _count = tl.zeros([BLOCK_N], dtype=tl.float32)

    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)

        _count += mask.to(tl.float32)
        delta = x - _mean
        # tl.maximum guards the division for lanes that have seen no column yet.
        _mean += tl.where(mask, delta / tl.maximum(_count, 1.0), 0.0)
        delta2 = x - _mean
        _m2 += tl.where(mask, delta * delta2, 0.0)

    # Merge the lanes:
    #   mean = sum(n_i * mean_i) / N
    #   M2   = sum(M2_i) + sum(n_i * (mean_i - mean)^2)
    # Simply summing the per-lane means / M2s is wrong. With one element per
    # lane it would give var == 0. Lanes that never saw a valid column have
    # n_i == 0 and drop out.
    mean = tl.sum(_count * _mean, axis=0) / N
    dev = _mean - mean
    var = (tl.sum(_m2, axis=0) + tl.sum(_count * dev * dev, axis=0)) / N
    rstd = 1.0 / tl.sqrt(var + eps)

    # Pass 2: Normalize + GELU
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(row_start + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + cols, mask=mask, other=1.0).to(tl.float32)
        b = tl.load(bias_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        x_norm = (x - mean) * rstd
        x_ln = x_norm * w + b

        # 0.5 * z * (1 + tanh(u)) == z * sigmoid(2u). tl.math.tanh is no
        # longer in tl.math on Triton 3.x.
        # 1.5957691216 == 2 * sqrt(2 / pi).
        inner = 1.5957691216 * (x_ln + 0.044715 * x_ln * x_ln * x_ln)
        y = x_ln * tl.sigmoid(inner)

        tl.store(out_start + cols * stride_xn, y.to(tl.float16), mask=mask)

Step 7: Verify and Benchmark Again

先用 Step 4 的正确性检查验证 v2 (包括 N 不是 BLOCK_N 整数倍、以及需要多个 chunk 的情况)。确认通过后,下面的性能数据才有意义 (示例数据):

v1 (3-pass): 0.45 ms
v2 (2-pass): 0.32 ms  (1.4x faster)
PyTorch ref: 0.55 ms

Speedup vs PyTorch: 1.72x  ✓

Step 8: Final Autotune

# Autotune the v2 kernel over (BLOCK_N, num_warps), re-tuning whenever N
# changes. This is equivalent to stacking @triton.autotune on top of
# @triton.jit over the v2 definition. BLOCK_N is now supplied by the
# autotuner, so callers must NOT pass it. Launch as:
#     fused_layernorm_gelu_final[(M,)](x, weight, bias, y, M, N, eps,
#                                      x.stride(0), x.stride(1))
fused_layernorm_gelu_final = triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 1024}, num_warps=4),
        triton.Config({"BLOCK_N": 2048}, num_warps=8),
        triton.Config({"BLOCK_N": 4096}, num_warps=8),
    ],
    key=["N"],
)(fused_layernorm_gelu_v2)

Takeaways

  1. LLM 可以生成功能正确的初始 kernel,但通常不是最优的
  2. Profiling feedback 是关键 — 让 LLM 理解具体的性能瓶颈
  3. 迭代优化比 one-shot 生成更有效
  4. 数值验证必须自动化 — 每次修改后都要验证正确性
  5. 最终仍需要人工审核 — 确认优化策略的合理性

Reference