Description
For the Attention module we can concatenate the weights and do one GEMM instead of three for the input projections to gain a speedup, since all three GEMMs are applied to the same input.
ao/torchao/_models/llama/model.py, lines 220 to 225 in 22d6f97:

```python
def load_hook(self, state_dict, prefix, *args):
    if prefix + "wq.weight" in state_dict:
        wq = state_dict.pop(prefix + "wq.weight")
        wk = state_dict.pop(prefix + "wk.weight")
        wv = state_dict.pop(prefix + "wv.weight")
        state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
```
and
ao/torchao/_models/llama/model.py, lines 230 to 231 in 22d6f97:

```python
kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
```
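For intuition, here is a quick standalone check (not repo code; the shapes are made up for illustration) that one GEMM over the concatenated weight matches the three separate projections:

```python
import torch

dim, kv_size = 64, 32
x = torch.randn(8, dim)

wq = torch.randn(dim, dim)      # query projection weight (out_features, in_features)
wk = torch.randn(kv_size, dim)  # key projection weight
wv = torch.randn(kv_size, dim)  # value projection weight

# Three separate GEMMs over the same input ...
q_ref, k_ref, v_ref = x @ wq.T, x @ wk.T, x @ wv.T

# ... equal one GEMM over the concatenated weight, followed by a split.
wqkv = torch.cat([wq, wk, wv])
q, k, v = (x @ wqkv.T).split([dim, kv_size, kv_size], dim=-1)

assert torch.allclose(q, q_ref) and torch.allclose(k, k_ref) and torch.allclose(v, v_ref)
```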
I suspect we can do the exact same thing for FeedForward:
ao/torchao/_models/llama/model.py, lines 262 to 263 in 22d6f97:

```python
def forward(self, x: Tensor) -> Tensor:
    return self.w2(F.silu(self.w1(x)) * self.w3(x))
```
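Here w1 and w3 both consume x, so they can be fused the same way; w2 cannot, since it consumes the elementwise product. A minimal sketch of what that could look like, mirroring the load_hook pattern above (the w13 name and this hook are my own invention, not existing repo code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class FeedForward(nn.Module):
    def __init__(self, dim: int, intermediate_size: int) -> None:
        super().__init__()
        # w1 and w3 fused into a single projection.
        self.w13 = nn.Linear(dim, 2 * intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, dim, bias=False)
        self.intermediate_size = intermediate_size
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Fuse checkpoints that still store w1/w3 separately.
        if prefix + "w1.weight" in state_dict:
            w1 = state_dict.pop(prefix + "w1.weight")
            w3 = state_dict.pop(prefix + "w3.weight")
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

    def forward(self, x: Tensor) -> Tensor:
        # One GEMM, then split the result back into the w1 and w3 halves.
        x1, x3 = self.w13(x).split(self.intermediate_size, dim=-1)
        return self.w2(F.silu(x1) * x3)
```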
Task:
Implement the above trick and rerun the benchmarks to show gains. If you don't have access to an A100, another (ideally similar) GPU is fine as a proxy. Also, if you can, confirm via a trace that the two GEMMs have indeed been turned into one.
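One way to do that trace check (a sketch, assuming a CUDA machine and the hypothetical fused FeedForward above; exact kernel names vary by GPU and dtype) is a short torch.profiler run, counting the linear/GEMM entries per forward pass:

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = FeedForward(dim=4096, intermediate_size=11008).half().cuda()
x = torch.randn(1, 128, 4096, dtype=torch.half, device="cuda")

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        model(x)

# Before the fusion there should be three GEMM launches per forward
# (w1, w3, w2); after it, two (w13, w2).
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```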