
Conversation


@davided0 davided0 commented Nov 24, 2025

Summary

Fuse w1 and w3 linear layers in the FeedForward module into a single w13 layer using horizontal fusion, reducing CUDA kernel launches and improving throughput by up to 24%.

Motivation

The LLaMA FeedForward module computes silu(w1(x)) * w3(x), which requires two independent matrix multiplications, w1(x) and w3(x), on the same input. Concatenating the w1 and w3 weights into a single w13 layer replaces these two GEMMs with one.

This mirrors the existing horizontal fusion optimization in the Attention module.
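
As a quick standalone illustration (not code from the PR), stacking the two weight matrices row-wise and splitting the output of a single GEMM reproduces the two separate projections; the sizes below are deliberately small, whereas the real Llama-2-7B shapes are dim=4096 and intermediate_size=11008:

import torch

# Small illustrative sizes only.
dim, hidden = 64, 128
x = torch.randn(1, dim)
w1 = torch.randn(hidden, dim)  # gate projection weight (w1)
w3 = torch.randn(hidden, dim)  # up projection weight (w3)

# Unfused: two independent GEMMs on the same input.
y1, y3 = x @ w1.T, x @ w3.T

# Fused: one GEMM against the row-wise concatenated weight, then split.
w13 = torch.cat([w1, w3], dim=0)           # shape (2 * hidden, dim)
z1, z3 = (x @ w13.T).chunk(2, dim=-1)

print(torch.allclose(y1, z1), torch.allclose(y3, z3))  # True True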

Changes

  • Combined w1 and w3 into a single w13 linear layer (2 * intermediate_size output features)
  • Added a _merge_w1_w3 state dict hook for backward compatibility with existing (unfused) checkpoints
  • Updated forward pass: single matmul → chunk() to split the output (both changes are sketched after this list)
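
A minimal sketch of the fused module and the checkpoint-compatibility hook listed above, assuming the usual bias-free Linear layers of the LLaMA FeedForward and mirroring the load-hook pattern used for the fused Attention projections; the exact constructor signature and hook registration in the PR's model.py may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, dim: int, intermediate_size: int) -> None:
        super().__init__()
        # Horizontal fusion: one (2 * intermediate_size, dim) weight replaces the
        # two (intermediate_size, dim) weights of w1 and w3.
        self.w13 = nn.Linear(dim, 2 * intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, dim, bias=False)
        # Rewrite old checkpoints (separate w1/w3) into the fused layout on load.
        self._register_load_state_dict_pre_hook(self._merge_w1_w3)

    @staticmethod
    def _merge_w1_w3(state_dict, prefix, *args):
        w1 = state_dict.pop(prefix + "w1.weight", None)
        w3 = state_dict.pop(prefix + "w3.weight", None)
        if w1 is not None and w3 is not None:
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3], dim=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x3 = self.w13(x).chunk(2, dim=-1)  # one GEMM, split into gate/up halves
        return self.w2(F.silu(x1) * x3)

Up to floating-point rounding in the larger GEMM, loading an unfused checkpoint through the hook produces the same outputs as the original w1/w3 formulation.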

Performance Analysis

Kernel Reduction

The number of nn.linear calls traced in feed_forward is reduced from 192 to 128 (a 33% reduction, as expected: each FeedForward now makes two Linear calls, w13 and w2, instead of three, w1, w2, and w3). This was measured with the following command:

  • TORCH_LOGS="graph_code" python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --compile --quantization int8wo --write_result benchmark_results.txt | grep feed_forward | grep -c "nn.linear("

The reduction of nn.linear calls is further confirmed by the FX graphs:

Before (2 separate matmuls):

mm: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute)      # w1
mm_1: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute_1)  # w3

After (1 fused matmul + split):

mm: "f32[1, 22016]" = torch.ops.aten.mm.default(primals_2, permute)  # w13
split = torch.ops.aten.split.Tensor(mm, 11008, -1)                   # chunk
New FX graph
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[22016, 4096]", primals_2: "f32[1, 4096]", primals_3: "f32[4096, 11008]"):
         # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:497 in forward, code: x1, x3 = self.w13(x).chunk(2, dim=-1)
        permute: "f32[4096, 22016]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        mm: "f32[1, 22016]" = torch.ops.aten.mm.default(primals_2, permute);  permute = None
        split = torch.ops.aten.split.Tensor(mm, 11008, -1);  mm = None
        getitem: "f32[1, 11008]" = split[0]
        getitem_1: "f32[1, 11008]" = split[1];  split = None
        
         # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:498 in forward, code: return self.w2(F.silu(x1) * x3)
        sigmoid: "f32[1, 11008]" = torch.ops.aten.sigmoid.default(getitem)
        mul: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(getitem, sigmoid);  sigmoid = None
        mul_1: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mul, getitem_1);  mul = None
        permute_1: "f32[11008, 4096]" = torch.ops.aten.permute.default(primals_3, [1, 0]);  primals_3 = None
        mm_1: "f32[1, 4096]" = torch.ops.aten.mm.default(mul_1, permute_1)
        permute_4: "f32[4096, 11008]" = torch.ops.aten.permute.default(permute_1, [1, 0]);  permute_1 = None
        return (mm_1, primals_2, getitem, getitem_1, mul_1, permute_4)
Original FX graph
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[11008, 4096]", primals_2: "f32[1, 4096]", primals_3: "f32[11008, 4096]", primals_4: "f32[4096, 11008]"):
         # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:486 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
        permute: "f32[4096, 11008]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        mm: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute);  permute = None
        sigmoid: "f32[1, 11008]" = torch.ops.aten.sigmoid.default(mm)
        mul: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mm, sigmoid);  sigmoid = None
        permute_1: "f32[4096, 11008]" = torch.ops.aten.permute.default(primals_3, [1, 0]);  primals_3 = None
        mm_1: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute_1);  permute_1 = None
        mul_1: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mul, mm_1);  mul = None
        permute_2: "f32[11008, 4096]" = torch.ops.aten.permute.default(primals_4, [1, 0]);  primals_4 = None
        mm_2: "f32[1, 4096]" = torch.ops.aten.mm.default(mul_1, permute_2)
        permute_5: "f32[4096, 11008]" = torch.ops.aten.permute.default(permute_2, [1, 0]);  permute_2 = None
        return (mm_2, primals_2, mm, mm_1, mul_1, permute_5)
Driver
import torch
from torch._inductor import config

# Dump Inductor debug artifacts, including the post-AOTAutograd FX graphs shown
# above (by default under ./torch_compile_debug/ in the working directory).
config.trace.enabled = True

from model import FeedForward, ModelArgs  # the torchao llama model definition

args = ModelArgs()
model = FeedForward(args)
model = torch.compile(model, fullgraph=True)

x = torch.rand(size=(1, args.dim))  # a single token's activation, shape (1, dim)
out = model(x)                      # first call triggers compilation and tracing
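
As an additional quick check (not part of the PR), the fused module should contain two Linear submodules (w13 and w2) instead of three (w1, w2, w3), which matches the 3:2 ratio of the 192 → 128 drop in traced calls:

import torch
from model import FeedForward, ModelArgs

model = FeedForward(ModelArgs())
n_linear = sum(isinstance(m, torch.nn.Linear) for m in model.modules())
print(n_linear)  # 2 with the fused w13 layout, 3 before this change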

Benchmark Results

Speedup is calculated as new_tok/s / baseline_tok/s, using the commands from torchao/_models/llama/benchmarks.sh; the other columns are computed the same way.
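
For concreteness, each speedup cell is just this ratio (the numbers below are made up for illustration; the real values come from paired baseline/PR runs of generate.py):

baseline_tok_s = 100.0  # hypothetical baseline throughput (tok/s)
fused_tok_s = 123.6     # hypothetical throughput with the fused w13 layer (tok/s)

speedup = fused_tok_s / baseline_tok_s
print(f"Speedup (tok/s): {speedup:.4f}")  # 1.2360 -> ~24% higher throughput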

  • GPU: A100-SXM4-40GB
  • CUDA: 11.8
  • PyTorch: 2.6.0+cu118

Notable improvements:

| Speedup (tok/s) | Speedup (tok/s_decode) | Speedup (ttft) | Speedup (mem/s) | Speedup (peak_mem) | Params |
|---|---|---|---|---|---|
| 1.236358749 | 1.24111364 | 0.9433497537 | 1.236358051 | 1.030627871 | quant: fp6 sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.119138562 | 1.122922546 | 0.9755244755 | 1.119154914 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.118612397 | 1.122824894 | 0.9876977153 | 1.118619349 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |

Most configurations are performance-neutral (within ±1%), with the largest gains in the quantized configurations listed above (up to ~24% tok/s for fp6 on Llama-2-7b-chat-hf).

Full benchmark table
Speedup (tok/s) Speedup (tok/s_decode) Speedup (ttft) Speedup (mem/s) Speedup (peak_mem) Params
0.9911355002 0.9907590074 0.9910447761 0.9910958849 1 quant: None sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.087821543 1.083854358 0.8843557382 1.087794498 0.9229922992 quant: int8dq sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
0.9950999355 0.9947361902 0.9940828402 0.9950826696 0.8942093541 quant: int8wo sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.236358749 1.24111364 0.9433497537 1.236358051 1.030627871 quant: fp6 sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
0.9922655878 0.9922562478 1 0.9922623814 1 quant: None sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.044583045 1.041311755 0.9201552537 1.044735709 1.02365416 quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.119138562 1.122922546 0.9755244755 1.119154914 1 quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.032804737 1.032319894 0.9538461538 1.03283344 1 quant: fp6 sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
0.9941659721 0.9936612278 0.972972973 0.9941143342 0.9836065574 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.118612397 1.122824894 0.9876977153 1.118619349 1 quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.013262599 1.013150973 0.9518716578 1.013111394 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192
1.006341154 1.005653266 1.009287926 1.005828441 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192 --kv_cache_quantization
1.000664894 1.000659196 0.9685279188 1.000531491 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192 --kv_cache_quantization --linear_causal_mask
1.009274874 1.008375209 1 1.008704931 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384
1.003857281 1.002870813 0.9502572899 1.003917791 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384 --kv_cache_quantization
1.003898635 1.003872217 0.9458544839 1.003636364 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384 --kv_cache_quantization --linear_causal_mask
0.9970457903 0.9970631424 0.951285521 0.9973430427 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768
1.005033557 1.005 0.9988066826 1.005026248 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768 --kv_cache_quantization
1.003367003 1.003344482 0.9929824561 1.003138662 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768 --kv_cache_quantization --linear_causal_mask
1.002873563 1.002857143 0.9963753524 1.002870813 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536
1.003215434 1 0.9958491871 1.001283422 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536 --kv_cache_quantization
1 1.003205128 0.9958085924 1.001286725 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536 --kv_cache_quantization --linear_causal_mask
1.045706371 1.041081311 0.9073971079 1.045563549 1.023333333 quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.053301512 1.043371723 0.8395337302 1.053222673 1 quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 32 --top_k 200 --temperature 0.8
1.01754386 1.012869565 0.9205726613 1.017601432 1.010028653 quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 128 --top_k 200 --temperature 0.8
0.999491353 0.9988133464 0.9824561404 0.9994782306 1 quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.044860076 1.044827115 0.953164557 1.044879946 0.9806501548 quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 32 --top_k 200 --temperature 0.8
1.094404394 1.096069869 0.9857336957 1.094362018 0.9884947267 quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 128 --top_k 200 --temperature 0.8
0.9984371948 0.9967086434 0.9965533748 0.9985293352 1.006455234 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.032615385 1.040481928 0.9943841258 1.032706258 1 quant: int8dq sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.013203825 1.018600098 0.9963788301 1.01321817 1.015944541 quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.018356868 1.02598419 0.9936114993 1.018331352 1.028186275 quant: sparse-marlin sparse: semi-structured mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.float16 device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
Raw data

Related Issues

Fixes #606 ([llama] Use horizontal fusion trick from Attention for FeedForward)


pytorch-bot bot commented Nov 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3380

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label on Nov 24, 2025