Merged
2 changes: 1 addition & 1 deletion .github/workflows/basic-tests-pixi.yml
@@ -28,7 +28,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
os: [ubuntu-latest, windows-latest]

steps:
- uses: actions/checkout@v4
6 changes: 5 additions & 1 deletion .gitignore
@@ -1,4 +1,3 @@

# Configs and keys
ch05/07_gpt_to_llama/config.json
ch07/02_dataset-utilities/config.json
@@ -78,6 +77,11 @@ ch07/01_main-chapter-code/gpt2-medium355M-sft-standalone.pth
ch07/01_main-chapter-code/Smalltestmodel-sft-standalone.pth
ch07/01_main-chapter-code/gpt2/

Qwen3-0.6B-Base/
Qwen3-0.6B/
tokenizer-base.json
tokenizer.json

# Datasets
the-verdict.txt

4 changes: 3 additions & 1 deletion pkg/llms_from_scratch/README.md
@@ -132,7 +132,8 @@ For more information about KV caching, please see the [KV cache README](../../ch

```python
from llms_from_scratch.llama3 import (
Llama3Model,
load_weights_into_llama,
Llama3Model,
Llama3ModelFast,
Llama3Tokenizer,
ChatFormat,
@@ -154,6 +155,7 @@ For more information about KV caching, please see the [KV cache README](../../ch

```python
from llms_from_scratch.qwen3 import (
load_weights_into_qwen,
Qwen3Model,
Qwen3Tokenizer,
)
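
The newly exported `load_weights_into_llama` and `load_weights_into_qwen` helpers copy a Hugging Face state dict into the corresponding from-scratch model. A minimal sketch of how the Qwen3 loader might be used is below; the Hub repo id, file name, and the `huggingface_hub`/`safetensors` dependencies are assumptions for illustration and not part of this PR:

```python
# Hedged sketch (not part of this PR): load pretrained Qwen3-0.6B-Base weights
# into the from-scratch model. Repo id and file name are assumptions.
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from llms_from_scratch.qwen3 import (
    load_weights_into_qwen,
    Qwen3Model,
    QWEN_CONFIG_06_B,
)

# Download the checkpoint and read it as a plain {name: tensor} dict
weights_file = hf_hub_download(repo_id="Qwen/Qwen3-0.6B-Base", filename="model.safetensors")
params = load_file(weights_file)

# Instantiate the from-scratch model and copy the pretrained weights into it
model = Qwen3Model(QWEN_CONFIG_06_B)
load_weights_into_qwen(model, QWEN_CONFIG_06_B, params)
model.eval()
```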
74 changes: 74 additions & 0 deletions pkg/llms_from_scratch/llama3.py
@@ -497,3 +497,77 @@ def forward(self, in_idx):
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))
return logits


def assign(left, right, tensor_name="unknown"):
if left.shape != right.shape:
raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

if isinstance(right, torch.Tensor):
return torch.nn.Parameter(right.clone().detach())
else:
return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_llama(model, param_config, params):
model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

for l in range(param_config["n_layers"]):

# Load attention weights
model.trf_blocks[l].att.W_query.weight = assign(
model.trf_blocks[l].att.W_query.weight,
params[f"model.layers.{l}.self_attn.q_proj.weight"],
f"model.layers.{l}.self_attn.q_proj.weight"
)
model.trf_blocks[l].att.W_key.weight = assign(
model.trf_blocks[l].att.W_key.weight,
params[f"model.layers.{l}.self_attn.k_proj.weight"],
f"model.layers.{l}.self_attn.k_proj.weight"
)
model.trf_blocks[l].att.W_value.weight = assign(
model.trf_blocks[l].att.W_value.weight,
params[f"model.layers.{l}.self_attn.v_proj.weight"],
f"model.layers.{l}.self_attn.v_proj.weight"
)
model.trf_blocks[l].att.out_proj.weight = assign(
model.trf_blocks[l].att.out_proj.weight,
params[f"model.layers.{l}.self_attn.o_proj.weight"],
f"model.layers.{l}.self_attn.o_proj.weight"
)
model.trf_blocks[l].norm1.weight = assign(
model.trf_blocks[l].norm1.weight,
params[f"model.layers.{l}.input_layernorm.weight"],
f"model.layers.{l}.input_layernorm.weight"
)

# Load FeedForward weights
model.trf_blocks[l].ff.fc1.weight = assign(
model.trf_blocks[l].ff.fc1.weight,
params[f"model.layers.{l}.mlp.gate_proj.weight"],
f"model.layers.{l}.mlp.gate_proj.weight"
)
model.trf_blocks[l].ff.fc2.weight = assign(
model.trf_blocks[l].ff.fc2.weight,
params[f"model.layers.{l}.mlp.up_proj.weight"],
f"model.layers.{l}.mlp.up_proj.weight"
)
model.trf_blocks[l].ff.fc3.weight = assign(
model.trf_blocks[l].ff.fc3.weight,
params[f"model.layers.{l}.mlp.down_proj.weight"],
f"model.layers.{l}.mlp.down_proj.weight"
)
model.trf_blocks[l].norm2.weight = assign(
model.trf_blocks[l].norm2.weight,
params[f"model.layers.{l}.post_attention_layernorm.weight"],
f"model.layers.{l}.post_attention_layernorm.weight"
)

# Load output layer weights
model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")

if "lm_head.weight" in params.keys():
model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
else:
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
print("Model uses weight tying.")
63 changes: 61 additions & 2 deletions pkg/llms_from_scratch/tests/test_llama3.py
@@ -5,11 +5,12 @@

from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import (
compute_rope_params,
apply_rope,
LLAMA32_CONFIG_1B,
compute_rope_params,
GroupedQueryAttention,
GroupedQueryAttentionFast,
load_weights_into_llama,
LLAMA32_CONFIG_1B,
Llama3Model,
)
from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
@@ -246,3 +247,61 @@ def test_rmsnorm_equivalence():
out2 = lit_norm(x)

torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)


@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_llama3_base_equivalence_with_transformers():
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
cfg = {
"vocab_size": 257,
"context_length": 8192,
"emb_dim": 32,
"n_heads": 4,
"n_layers": 2,
"hidden_dim": 64,
"n_kv_groups": 2,
"rope_base": 500_000.0,
"rope_freq": {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_context_length": 8192,
},
"dtype": torch.float32,
}

ours = Llama3Model(cfg)

hf_cfg = LlamaConfig(
vocab_size=cfg["vocab_size"],
hidden_size=cfg["emb_dim"],
num_attention_heads=cfg["n_heads"],
num_key_value_heads=cfg["n_kv_groups"],
num_hidden_layers=cfg["n_layers"],
intermediate_size=cfg["hidden_dim"],
max_position_embeddings=cfg["context_length"],
rms_norm_eps=1e-5,
attention_bias=False,
rope_theta=cfg["rope_base"],
tie_word_embeddings=False,
attn_implementation="eager",
torch_dtype=torch.float32,
rope_scaling={
"type": "llama3",
"factor": cfg["rope_freq"]["factor"],
"low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
"high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
"original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
},
)
theirs = LlamaForCausalLM(hf_cfg)

hf_state = theirs.state_dict()
load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)

x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
ours_logits = ours(x)
theirs_logits = theirs(x).logits.to(ours_logits.dtype)

torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
58 changes: 55 additions & 3 deletions pkg/llms_from_scratch/tests/test_qwen3.py
@@ -5,12 +5,13 @@

from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.qwen3 import (
compute_rope_params,
apply_rope,
compute_rope_params,
load_weights_into_qwen,
QWEN_CONFIG_06_B,
RMSNorm,
Qwen3Model,
Qwen3Tokenizer
Qwen3Tokenizer,
RMSNorm,
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.utils import KVCache
@@ -87,6 +88,7 @@ def dummy_cfg_moe(dummy_cfg_base):
return cfg


@torch.inference_mode()
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
torch.manual_seed(123)
model = Qwen3Model(dummy_cfg_base)
@@ -95,6 +97,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"


@torch.inference_mode()
def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
torch.manual_seed(123)
model = Qwen3Model(dummy_cfg_moe)
@@ -105,6 +108,7 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
"Expected MoEFeedForward in at least one transformer block"


@torch.inference_mode()
@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
def test_qwen3_kvcache_equivalence(cfg_name, request):
cfg = request.getfixturevalue(cfg_name)
@@ -438,3 +442,51 @@ def test_tokenizer_equivalence():
expected_pad_token = "<|endoftext|>"
assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token


@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_qwen3_base_equivalence_with_transformers():

from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM

# Tiny config so the test is fast
cfg = {
"vocab_size": 257,
"context_length": 8,
"emb_dim": 32,
"n_heads": 4,
"n_layers": 2,
"hidden_dim": 64,
"head_dim": 8,
"qk_norm": True,
"n_kv_groups": 2,
"rope_base": 1_000_000.0,
"dtype": torch.float32,
}
model = Qwen3Model(cfg)

hf_cfg = Qwen3Config(
vocab_size=cfg["vocab_size"],
max_position_embeddings=cfg["context_length"],
hidden_size=cfg["emb_dim"],
num_attention_heads=cfg["n_heads"],
num_hidden_layers=cfg["n_layers"],
intermediate_size=cfg["hidden_dim"],
head_dim=cfg["head_dim"],
num_key_value_heads=cfg["n_kv_groups"],
rope_theta=cfg["rope_base"],
tie_word_embeddings=False,
attn_implementation="eager",
torch_dtype=torch.float32,
)
hf_model = Qwen3ForCausalLM(hf_cfg)

hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
load_weights_into_qwen(model, param_config, hf_state)

x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)
theirs_logits = hf_model(x).logits
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)