[llama] Use horizontal fusion trick from Attention for FeedForward #606

@cpuhrsch

Description

For the Attention module we can concatenate the weights and do one GEMM instead of three for the input projections, which gives a speedup because each of those GEMMs is applied to the same input.

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

and

    kv_size = self.n_local_heads * self.head_dim
    q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
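For context, a sketch of how the load hook is typically wired up in the module's constructor (an assumption about the wiring, not quoted from the repo), so that checkpoints saved with separate wq/wk/wv weights still load into the fused wqkv Linear:

    # Sketch: register the pre-hook in Attention.__init__ so checkpoints
    # with separate wq/wk/wv weights map onto the fused wqkv Linear.
    self._register_load_state_dict_pre_hook(self.load_hook)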

I suspect we can do the exact same thing for FeedForward:

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
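Here w1 and w3 both consume the same input x, so they can be concatenated into a single projection, exactly like wqkv above. A minimal sketch of what that could look like (the names `FusedFeedForward`, `w13`, and `intermediate_size` are assumptions for illustration, not the repo's final API):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch import Tensor


    class FusedFeedForward(nn.Module):
        def __init__(self, dim: int, intermediate_size: int) -> None:
            super().__init__()
            # w1 and w3 read the same input x, so their weights can be
            # concatenated into one Linear, mirroring the wqkv trick.
            self.w13 = nn.Linear(dim, 2 * intermediate_size, bias=False)
            self.w2 = nn.Linear(intermediate_size, dim, bias=False)
            self.intermediate_size = intermediate_size
            self._register_load_state_dict_pre_hook(self.load_hook)

        def load_hook(self, state_dict, prefix, *args):
            # Fold checkpoints saved with separate w1/w3 weights into w13.
            if prefix + "w1.weight" in state_dict:
                w1 = state_dict.pop(prefix + "w1.weight")
                w3 = state_dict.pop(prefix + "w3.weight")
                state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

        def forward(self, x: Tensor) -> Tensor:
            # One GEMM yields both projections; split recovers w1(x) and w3(x).
            x1, x3 = self.w13(x).split(
                [self.intermediate_size, self.intermediate_size], dim=-1
            )
            return self.w2(F.silu(x1) * x3)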

Task:
Implement the above trick and rerun the benchmarks to show the gains. If you don't have access to an A100, another (ideally similar) GPU is fine as a proxy. Also, if you can, try to confirm via a trace that the two GEMMs have indeed been turned into one.
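One hypothetical way to check this via a trace is to profile a forward pass with torch.profiler and inspect the kernel table (the shapes below roughly match Llama-7B and are placeholders; `FusedFeedForward` is the sketch above):

    import torch
    from torch.profiler import profile, ProfilerActivity

    ff = FusedFeedForward(dim=4096, intermediate_size=11008).cuda().half()
    x = torch.randn(1, 1, 4096, device="cuda", dtype=torch.half)

    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        ff(x)

    # The table should show one projection GEMM for w13 plus one for w2,
    # rather than three separate GEMMs for w1, w3, and w2.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))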
