# Davided0/feedforward horizontal fusion #3380
## Summary

Fuse the `w1` and `w3` linear layers in the `FeedForward` module into a single `w13` layer via horizontal fusion, reducing CUDA kernel launches and improving throughput by up to 24%.
## Motivation

The LLaMA `FeedForward` module computes `silu(w1(x)) * w3(x)`, which requires two independent matrix multiplications, `w1(x)` and `w3(x)`, on the same input. Concatenating the `w1` and `w3` weights into a single `w13` layer halves the number of GEMM calls for this block, mirroring the existing horizontal fusion optimization in the `Attention` module.
Changes
w1andw3into a singlew13linear layer (2 * intermediate_sizeoutput features)_merge_w1_w3state dict hook for backward compatibility with existing checkpointschunk()to split outputPerformance Analysis
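A hedged sketch of what the fused module and compatibility hook could look like (the names `w13` and `_merge_w1_w3` come from this PR; the surrounding module layout is an assumption for illustration, not the exact torchao implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, dim: int, intermediate_size: int) -> None:
        super().__init__()
        # w1 and w3 fused into one layer with 2 * intermediate_size output features.
        self.w13 = nn.Linear(dim, 2 * intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, dim, bias=False)
        # Rewrite legacy w1/w3 checkpoints into the fused layout at load time.
        self._register_load_state_dict_pre_hook(self._merge_w1_w3)

    def _merge_w1_w3(self, state_dict, prefix, *args):
        if prefix + "w1.weight" in state_dict:
            w1 = state_dict.pop(prefix + "w1.weight")
            w3 = state_dict.pop(prefix + "w3.weight")
            # Stacking rows means a single GEMM computes both projections.
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3], dim=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x3 = self.w13(x).chunk(2, dim=-1)
        return self.w2(F.silu(x1) * x3)
```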
## Performance Analysis

### Kernel Reduction
The number of `nn.linear` calls in `feed_forward` is reduced from 192 to 128 (a 33% reduction, as expected: each `FeedForward` block goes from three linear calls, `w1`, `w2`, and `w3`, to two, `w13` and `w2`). This was measured with the following command:

```bash
TORCH_LOGS="graph_code" python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --compile --quantization int8wo --write_result benchmark_results.txt | grep feed_forward | grep -c "nn.linear("
```

The reduction of `nn.linear` calls is further confirmed by the FX graphs.

Before (2 separate matmuls):

After (1 fused matmul + split):
New FX graph
Original FX graph
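For readers who skip the collapsed graphs, here is a plain-PyTorch illustration of the structural change (an illustration of the shape of the computation, not the actual FX-generated code):

```python
import torch
import torch.nn.functional as F

def feed_forward_before(x, w1, w3, w2):
    # Two independent GEMMs over the same input x.
    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

def feed_forward_after(x, w13, w2):
    # One fused GEMM, then chunk() splits the result into the w1/w3 halves.
    y1, y3 = F.linear(x, w13).chunk(2, dim=-1)
    return F.linear(F.silu(y1) * y3, w2)

# The two match when w13 = torch.cat([w1, w3], dim=0).
```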
Driver
## Benchmark Results
Speedup is calculated as `new_tok/s / baseline_tok/s`, using the commands from `torchao/_models/llama/benchmarks.sh`; the other columns are computed similarly.

Most configurations show neutral performance (within ±1%), with notable gains in specific quantization scenarios.
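The speedup column is just the ratio of the two throughput numbers; for example (hypothetical values, not taken from the benchmark table):

```python
baseline_tok_s = 100.0  # hypothetical baseline tokens/s
new_tok_s = 124.0       # hypothetical fused tokens/s
speedup = new_tok_s / baseline_tok_s
print(f"{speedup:.2f}x")  # 1.24x, i.e. +24% throughput
```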
Full benchmark table
Raw data
## Related Issues
Fixes #606