
Commit 27d52d6

aviralgarg05 and rasbt authored
Fix MHAEinsum weight dimension bug when d_in != d_out (#857) (#893)
* Fix MHAEinsum weight dimension bug when d_in != d_out (#857)

  Previously, MHAEinsum initialized its weight matrices with shape (d_out, d_in) and used einsum notation that did not match, which caused failures whenever the input and output dimensions differed. This commit corrects the weight initialization to shape (d_in, d_out), updates the einsum notation to 'bnd,do->bno', and adds three unit tests that verify parity between the implementations across different d_in and d_out settings. All tests pass.

* use pytest

* Update .gitignore

---------

Co-authored-by: rasbt <[email protected]>
1 parent b1db33b commit 27d52d6

4 files changed, +80 −14 lines changed


.github/workflows/basic-tests-linux-uv.yml

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ jobs:
       run: |
         source .venv/bin/activate
         pytest setup/02_installing-python-libraries/tests.py
+        pytest ch03/02_bonus_efficient-multihead-attention/tests/test_mha_implementations.py
         pytest ch04/01_main-chapter-code/tests.py
         pytest ch04/03_kv-cache/tests.py
         pytest ch05/01_main-chapter-code/tests.py

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -328,3 +328,5 @@ cython_debug/
 # pixi environments
 .pixi
 *.egg-info
+
+

ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb

Lines changed: 14 additions & 14 deletions
Large diffs are not rendered by default.
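
The notebook diff itself is not rendered, but based on the commit message the fix amounts to storing the Q/K/V weights with shape (d_in, d_out) and projecting with the einsum notation 'bnd,do->bno'. Below is a minimal sketch of that corrected projection; the class and attribute names are placeholders, not the exact MHAEinsum code from the notebook.

import torch
import torch.nn as nn


class EinsumProjectionSketch(nn.Module):
    # Illustrative only: mirrors the fix described in the commit message,
    # not the notebook's actual MHAEinsum implementation.
    def __init__(self, d_in, d_out):
        super().__init__()
        # Corrected: weights stored as (d_in, d_out) rather than (d_out, d_in)
        self.W_query = nn.Parameter(torch.randn(d_in, d_out) * 0.02)

    def forward(self, x):
        # x: (batch, num_tokens, d_in) -> queries: (batch, num_tokens, d_out)
        # Corrected notation: contract over the shared d_in index 'd'
        return torch.einsum("bnd,do->bno", x, self.W_query)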
ch03/02_bonus_efficient-multihead-attention/tests/test_mha_implementations.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@

from pathlib import Path
import torch
import pytest

from llms_from_scratch.utils import import_definitions_from_notebook


@pytest.fixture
def nb_imports():
    nb_dir = Path(__file__).resolve().parents[1]
    mod = import_definitions_from_notebook(nb_dir, "mha-implementations.ipynb")
    return mod


def copy_weights(from_mha, to_mha):
    with torch.no_grad():
        to_mha.W_query.copy_(from_mha.W_query.weight.T)
        to_mha.W_key.copy_(from_mha.W_key.weight.T)
        to_mha.W_value.copy_(from_mha.W_value.weight.T)

        to_mha.out_proj.weight.copy_(from_mha.out_proj.weight)
        to_mha.out_proj.bias.copy_(from_mha.out_proj.bias)


@pytest.mark.parametrize(
    "d_in,d_out,batch,seq_len,num_heads,seed",
    [
        (768, 768, 2, 4, 12, 123),    # d_in == d_out
        (768, 1536, 2, 4, 12, 456),   # d_in != d_out
        (1024, 512, 2, 4, 8, 789),    # d_in > d_out
    ],
)
def test_mha_einsum_matches_ch03(d_in, d_out, batch, seq_len, num_heads, seed, nb_imports):
    torch.manual_seed(seed)

    x = torch.randn(batch, seq_len, d_in)

    mha_linear = nb_imports.Ch03_MHA(
        d_in=d_in,
        d_out=d_out,
        context_length=seq_len,
        dropout=0.0,
        num_heads=num_heads,
        qkv_bias=False,
    ).eval()

    mha_einsum = nb_imports.MHAEinsum(
        d_in=d_in,
        d_out=d_out,
        context_length=seq_len,
        dropout=0.0,
        num_heads=num_heads,
        qkv_bias=False,
    ).eval()

    copy_weights(mha_linear, mha_einsum)

    out_linear = mha_linear(x)
    out_einsum = mha_einsum(x)

    assert out_linear.shape == out_einsum.shape == torch.Size([batch, seq_len, d_out])
    assert torch.allclose(out_linear, out_einsum, atol=1e-5)
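
A note on the weight copy in the test above: nn.Linear stores its weight with shape (out_features, in_features) and computes x @ weight.T, while the einsum parameters are stored as (d_in, d_out), so copying weight.T makes both modules compute the same projection. A small standalone check of that equivalence (the variable names here are illustrative, not from the repository):

import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 1024, 512
x = torch.randn(2, 4, d_in)

linear = nn.Linear(d_in, d_out, bias=False)    # weight shape: (d_out, d_in)
W = linear.weight.T.clone()                    # einsum-style shape: (d_in, d_out)

out_linear = linear(x)                              # computes x @ weight.T
out_einsum = torch.einsum("bnd,do->bno", x, W)      # contracts over d_in

assert torch.allclose(out_linear, out_einsum, atol=1e-6)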
