From 8a3f12d20d6a0b47e48ce2102863d800533ad552 Mon Sep 17 00:00:00 2001 From: rasbt Date: Sun, 19 Oct 2025 21:42:34 -0500 Subject: [PATCH 1/4] Mixture-of-Experts intro --- .gitignore | 5 +- README.md | 1 + ch04/06_swa/README.md | 6 +- ch04/06_swa/memory_estimator_mla.py | 123 ----- ch04/06_swa/plot_memory_estimates_mla.py | 90 ---- ch04/07_moe/README.md | 160 ++++++ ch04/07_moe/ffn_moe_memory_estimator.py | 127 +++++ ch04/07_moe/gpt_with_kv_ffn.py | 415 +++++++++++++++ ch04/07_moe/gpt_with_kv_moe.py | 490 ++++++++++++++++++ ch04/07_moe/plot_memory_estimates_moe.py | 123 +++++ ch04/README.md | 1 + .../standalone-qwen3-moe-plus-kvcache.ipynb | 10 +- ch05/11_qwen3/standalone-qwen3-moe.ipynb | 12 +- 13 files changed, 1334 insertions(+), 229 deletions(-) delete mode 100644 ch04/06_swa/memory_estimator_mla.py delete mode 100644 ch04/06_swa/plot_memory_estimates_mla.py create mode 100644 ch04/07_moe/README.md create mode 100644 ch04/07_moe/ffn_moe_memory_estimator.py create mode 100644 ch04/07_moe/gpt_with_kv_ffn.py create mode 100644 ch04/07_moe/gpt_with_kv_moe.py create mode 100644 ch04/07_moe/plot_memory_estimates_moe.py diff --git a/.gitignore b/.gitignore index ef8f7fec5..f425155fe 100644 --- a/.gitignore +++ b/.gitignore @@ -12,8 +12,9 @@ appendix-D/01_main-chapter-code/3.pdf appendix-E/01_main-chapter-code/loss-plot.pdf ch04/04_gqa/kv_bytes_vs_context_length.pdf -ch05/05_mla/kv_bytes_vs_context_length.pdf -ch06/06_swa/kv_bytes_vs_context_length.pdf +ch04/05_mla/kv_bytes_vs_context_length.pdf +ch04/06_swa/kv_bytes_vs_context_length.pdf +ch04/07_moe/ffn_vs_moe.pdf ch05/01_main-chapter-code/loss-plot.pdf ch05/01_main-chapter-code/temperature-plot.pdf diff --git a/README.md b/README.md index dd189dace..9c3c3edec 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,7 @@ Several folders contain optional materials as a bonus for interested readers: - [Grouped-Query Attention](ch04/04_gqa) - [Multi-Head Latent Attention](ch04/05_mla) - [Sliding Window Attention](ch04/06_swa) + - [Mixture-of-Experts (MoE)](ch04/07_moe) - **Chapter 5: Pretraining on unlabeled data:** - [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/) - [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg) diff --git a/ch04/06_swa/README.md b/ch04/06_swa/README.md index 0520e7546..5ee775257 100644 --- a/ch04/06_swa/README.md +++ b/ch04/06_swa/README.md @@ -71,14 +71,14 @@ The savings when using SWA over MHA are further shown in the plot below for diff   -SWA +SWA   -You can reproduce these plots via: +You can reproduce thi plots via: ```bash -plot_memory_estimates_swa.py \ +uv run plot_memory_estimates_swa.py \ --emb_dim 4096 --n_heads 48 --n_layers 36 \ --batch_size 1 --dtype bf16 \ --sliding_window_size 2048 --swa_ratio "5:1" diff --git a/ch04/06_swa/memory_estimator_mla.py b/ch04/06_swa/memory_estimator_mla.py deleted file mode 100644 index f9ab9f512..000000000 --- a/ch04/06_swa/memory_estimator_mla.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). 
-# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch -# -# KV-cache memory estimator for MHA vs GQA vs MLA - -import argparse -import math - -DTYPE_BYTES = { - "fp32": 4, - "bf16": 2, - "fp16": 2, - "fp8": 1, - "int8": 1, -} - - -def bytes_convert(n): - gb = n / (1000 ** 3) - return f"{gb:,.2f} GB" - - -def kv_bytes_total(batch, context_length, emb_dim, n_heads, - n_kv_heads, n_layers, bytes_per_elem): - # Generic KV-cache: per-head dim is embed_dim / n_heads, times 2 for K and V - head_dim = math.ceil(emb_dim / n_heads) - per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem - return per_layer * n_layers - - -def mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem): - # Simple MLA (per-token compressed latent) - # bytes ≈ batch × seqlen × n_layers × latent_dim × bytes_per_elem - return batch * context_length * n_layers * latent_dim * bytes_per_elem - - -def main(): - p = argparse.ArgumentParser(description="Estimate KV-cache memory for MHA vs GQA vs MLA") - p.add_argument("--context_length", default=1024, type=int) - p.add_argument("--emb_dim", required=True, type=int) - p.add_argument("--n_heads", required=True, type=int) - p.add_argument("--n_layers", required=True, type=int) - p.add_argument("--n_kv_groups", required=True, type=int) - p.add_argument("--latent_dim", required=True, type=int, help="MLA per-token latent dimension") - p.add_argument("--batch_size", default=1, type=int) - p.add_argument("--dtype", choices=DTYPE_BYTES.keys(), default="fp16") - args = p.parse_args() - - cfg = { - "context_length": args.context_length, - "emb_dim": args.emb_dim, - "n_heads": args.n_heads, - "n_layers": args.n_layers, - "n_kv_groups": args.n_kv_groups, - "latent_dim": args.latent_dim, - } - - if cfg["n_heads"] % cfg["n_kv_groups"] != 0: - raise ValueError("n_kv_groups must divide n_heads exactly.") - - bytes_per_elem = DTYPE_BYTES[args.dtype] - head_dim = math.ceil(cfg["emb_dim"] / cfg["n_heads"]) - - n_kv_heads_mha = cfg["n_heads"] - n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"] - - total_mha = kv_bytes_total( - args.batch_size, - cfg["context_length"], - cfg["emb_dim"], - cfg["n_heads"], - n_kv_heads_mha, - cfg["n_layers"], - bytes_per_elem, - ) - - total_gqa = kv_bytes_total( - args.batch_size, - cfg["context_length"], - cfg["emb_dim"], - cfg["n_heads"], - n_kv_heads_gqa, - cfg["n_layers"], - bytes_per_elem, - ) - - total_mla = mla_bytes_total( - args.batch_size, - cfg["context_length"], - cfg["n_layers"], - cfg["latent_dim"], - bytes_per_elem, - ) - - ratio = total_mha / total_gqa if total_gqa != 0 else float("inf") - savings = 1 - (total_gqa / total_mha) if total_mha != 0 else 0.0 - - ratio_mha_mla = total_mha / total_mla if total_mla != 0 else float("inf") - savings_mla = 1 - (total_mla / total_mha) if total_mha != 0 else 0.0 - - print("==== Config ====") - for k, v in cfg.items(): - print(f"{k:17}: {v}") - print(f"batch_size : {args.batch_size}") - print(f"dtype : {args.dtype} ({bytes_per_elem} Bytes/elem)") - print(f"head_dim : {head_dim}") - print(f"GQA n_kv_heads : {n_kv_heads_gqa}") - print() - - print("==== KV-cache totals across all layers ====") - print(f"MHA total KV cache : {bytes_convert(total_mha)}") - print(f"GQA total KV cache : {bytes_convert(total_gqa)}") - print(f"MLA total KV cache : {bytes_convert(total_mla)}") - print(f"Ratio (MHA / GQA) : {ratio:,.2f}x") - print(f"Savings (GQA vs 
MHA): {savings*100:,.2f}%") - print(f"Ratio (MHA / MLA) : {ratio_mha_mla:,.2f}x") - print(f"Savings (MLA vs MHA): {savings_mla*100:,.2f}%") - - -if __name__ == "__main__": - main() diff --git a/ch04/06_swa/plot_memory_estimates_mla.py b/ch04/06_swa/plot_memory_estimates_mla.py deleted file mode 100644 index e4c420880..000000000 --- a/ch04/06_swa/plot_memory_estimates_mla.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch - -import matplotlib.pyplot as plt - -# Bytes per element -DTYPE_BYTES = { - "fp32": 4, - "bf16": 2, - "fp16": 2, - "fp8": 1, - "int8": 1, -} - - -def bytes_to_gb(n_bytes): - return n_bytes / (1000. ** 3) - - -def kv_bytes_total_mha(batch, context_length, emb_dim, n_heads, - n_layers, bytes_per_elem): - head_dim = emb_dim / n_heads - per_layer = batch * context_length * head_dim * n_heads * 2 * bytes_per_elem - return per_layer * n_layers - - -def kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem): - return batch * context_length * n_layers * latent_dim * bytes_per_elem - - -def plot_abs_kv_vs_context_multiple(): - n_heads = 24 - emb_dim = 2048 - n_layers = 48 - batch_size = 1 - dtype = "bf16" - bytes_per_elem = DTYPE_BYTES[dtype] - - context_lengths = [ - 256, 512, 1024, 2048, 4096, 8192, - 16384, 32768, 65536, 131072 - ] - - mha_gb = [] - for L in context_lengths: - total_mha = kv_bytes_total_mha( - batch_size, L, emb_dim, n_heads, n_layers, bytes_per_elem - ) - mha_gb.append(bytes_to_gb(total_mha)) - - latent_dims = [1024, 512, 256, 64] - plt.figure() - plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)") - - L_ref = context_lengths[-1] - total_mha_ref = kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem) - - for latent_dim in latent_dims: - mla_gb = [] - for L in context_lengths: - total_mla = kv_bytes_total_mla( - batch_size, L, n_layers, latent_dim, bytes_per_elem - ) - mla_gb.append(bytes_to_gb(total_mla)) - - total_mla_ref = kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem) - comp = total_mha_ref / total_mla_ref if total_mla_ref != 0 else float("inf") - - plt.plot(context_lengths, mla_gb, marker="o", - label=f"MLA (latent_dim={latent_dim}, {comp:,.1f}× compression)") - - plt.xscale("log") - plt.xlabel("context_length (log scale)") - plt.ylabel("Total KV cache (GB)") - plt.title( - "KV-cache vs Context Length — MHA vs MLA\n" - f"(n_heads={n_heads}, emb_dim={emb_dim}, n_layers={n_layers}, " - f"batch={batch_size}, dtype={dtype})", - fontsize=8 - ) - plt.grid(True, which="both") - plt.legend() - plt.tight_layout() - plt.savefig("kv_bytes_vs_context_length.pdf") - - -if __name__ == "__main__": - plot_abs_kv_vs_context_multiple() diff --git a/ch04/07_moe/README.md b/ch04/07_moe/README.md new file mode 100644 index 000000000..440adb2f8 --- /dev/null +++ b/ch04/07_moe/README.md @@ -0,0 +1,160 @@ +# Mixture of Experts (MoE) + +This bonus material illustrates the memory savings (per token) when using Mixture-of-Experts (MoE) layers instead of regular feed-forward (FFN) layers. + + + +  +## Introduction + +The core idea in MoE is to replace each feed-forward module in a transformer block with multiple expert layers, where each of these expert layers is also a feed-forward module. 
This means we replace a single feed-forward block with multiple feed-forward blocks, as illustrated in the figure below.

&nbsp;

MoE

The feed-forward block inside a transformer block (shown as the dark gray block in the figure above) typically contains a large number of the model's total parameters. (Note that the transformer block, and thereby the feed-forward block, is repeated many times in an LLM; in the case of DeepSeek-V3, 61 times.)

So, replacing *a single* feed-forward block with *multiple* feed-forward blocks (as done in an MoE setup) substantially increases the model's total parameter count. However, the key trick is that we don't use ("activate") all experts for every token. Instead, a router selects only a small subset of experts per token.

Because only a few experts are active at a time, MoE modules are often referred to as *sparse*, in contrast to *dense* modules that always use the full parameter set. However, the large total parameter count of an MoE increases the capacity of the LLM, which means it can absorb more knowledge during training. The sparsity keeps inference efficient, though, as we don't use all the parameters at the same time.

For example, DeepSeek-V3 has 256 experts per MoE module and a total of 671 billion parameters. Yet during inference, only 9 experts are active at a time (1 shared expert plus 8 selected by the router). This means just 37 billion parameters are used for each token inference step, as opposed to all 671 billion.

One notable feature of DeepSeek-V3's MoE design is the use of a shared expert, i.e., an expert that is always active for every token. This idea is not new and was already introduced in the [2022 DeepSpeedMoE](https://arxiv.org/abs/2201.05596) and the [2024 DeepSeekMoE](https://arxiv.org/abs/2401.06066) papers.

&nbsp;

MoE shared expert

(An annotated figure from the [DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models](https://arxiv.org/abs/2401.06066) paper.)

&nbsp;

The benefit of having a shared expert was first noted in the [DeepSpeedMoE paper](https://arxiv.org/abs/2201.05596), where the authors found that it boosts overall modeling performance compared to no shared experts. This is likely because common or repeated patterns don't have to be learned by multiple individual experts, which leaves them with more room for learning more specialized patterns.

&nbsp;
## Mixture of Experts (MoE) Memory Savings

The memory savings in MoE models primarily come from reduced activation storage and compute. In a regular (dense) feed-forward layer (FFN), every token activates the full intermediate dimension.

In contrast, an MoE layer routes each token through only a small subset of experts (for example, `top_k` out of `num_experts`).

When using an MoE layer, only `top_k` experts are active per token, so the effective memory (and compute) scales by roughly a factor of `top_k / num_experts` relative to a dense FFN of the same total capacity.
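To make the `top_k / num_experts` factor concrete, here is a small back-of-the-envelope sketch (plain Python, no PyTorch required). It assumes bias-free SwiGLU experts (three weight matrices each) and a single linear router, matching the implementations in this folder; the concrete numbers correspond to the `--match_dense` example run of the estimator script shown below.

```python
# Rough parameter-count sketch for an MoE layer built from SwiGLU experts.
# (Illustrative numbers; they match the --match_dense example run below.)
emb_dim = 7168          # input/output dimension of the layer
moe_hidden_dim = 1792   # per-expert intermediate size
num_experts, top_k = 8, 2

per_expert = 3 * emb_dim * moe_hidden_dim   # SwiGLU: fc1, fc2, fc3 (no biases)
router = emb_dim * num_experts              # single linear gating layer

total_params = num_experts * per_expert + router
active_params = top_k * per_expert + router  # only top_k experts run per token

print(f"MoE total params : {total_params:,}")    # 308,338,688
print(f"MoE active/token : {active_params:,}")   # 77,127,680
print(f"Active fraction  : {active_params / total_params:.1%}")  # ~ top_k / num_experts = 25%
```

For anything beyond this rough estimate, the estimator script described next also handles the GELU variant and derives the per-expert hidden size automatically.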

You can use the [memory_estimator_moe.py](memory_estimator_moe.py) script in this folder to apply this for different model configs to see how much memory you can save by using MoE over FFN (note that this is for a single transformer block, to get the total savings, multiply by the number of transformer blocks in your model):

```bash
uv run ffn_moe_memory_estimator.py --emb_dim 7168 --hidden_dim 14336 --ffn_type swiglu \
--num_experts 8 --top_k 2 --match_dense
==== Config ====
emb_dim                : 7168
hidden_dim             : 14336
ffn_type               : swiglu
num_experts            : 8
top_k                  : 2
dtype                  : bf16 (2 Bytes/elem)
match_dense            : True

==== Model weights (parameters) ====
Dense FFN params       : 308,281,344 (0.62 GB)
Per-expert params      : 38,535,168 (0.08 GB)
Router params          : 57,344 (0.00 GB)
MoE TOTAL params       : 308,338,688 (0.62 GB)
MoE ACTIVE/Token       : 77,127,680 (0.15 GB)
moe_hidden_dim         : 1792
```

So, based on the results above, we can see that if we have an FFN with an input/output dimension (`emb_dim`) of 7,168 and an intermediate size (`hidden_dim`) of 14,336, we have ~308M parameters in this layer, and all of these parameters are active in the forward pass.

Now, if we use an MoE layer with roughly the same number of total parameters (~308M), with 8 experts of which 2 are active, only ~77M parameters are active in each forward pass.

Moreover, if we keep the number of active experts per token (`top_k`) constant, the more experts we have in total, the lower the number of active parameters becomes, and the greater the "savings":

&nbsp;

&nbsp;

MoE

&nbsp;

You can reproduce this plot via:

```bash
uv run plot_memory_estimates_moe.py \
    --emb_dim 7168 \
    --hidden_dim 28672 \
    --ffn_type swiglu \
    --top_k 8
```


&nbsp;
## MoE Code Examples

The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_moe.py](gpt_with_kv_moe.py) scripts in this folder provide hands-on examples for comparing the regular FFN and MoE memory usage in the context of a GPT model implementation. Note that both scripts use [SwiGLU](https://arxiv.org/abs/2002.05202) feed-forward modules as shown in the first figure of this page (GPT-2 traditionally uses GELU).

**Note: The model is not trained and thus generates nonsensical text. You can find a trained MoE in the bonus materials at [../ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb](../ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb).**


First, let's run the model with a regular FFN:


```bash
uv run gpt_with_kv_ffn.py \
--max_new_tokens 1024 \
--n_heads 16 \
--n_layers 12 \
--emb_dim 4096 \
--hidden_dim 32768

...
Avg FFN time/call: 0.759 ms
Avg FFN mem delta/call: 0.19 MB (max 0.75 MB)
...
Time: 25.13 sec
40 tokens/sec
Max memory allocated: 11.47 GB
```

For a fair comparison with an MoE, we have to shrink the expert size. E.g., if we use 32 experts, we have to set `--hidden_dim` to 32768/32 = 1024:


```bash
uv run gpt_with_kv_moe.py \
--max_new_tokens 1024 \
--n_heads 16 \
--n_layers 12 \
--emb_dim 4096 \
--hidden_dim 1024 \
--num_experts 32 \
--num_experts_per_tok 2

...
Avg MoE FF time/call: 1.555 ms
Avg MoE FF mem delta/call: 0.04 MB (max 0.11 MB)
...
Time: 35.11 sec
29 tokens/sec
Max memory allocated: 11.48 GB
```

We can see that the dense feed-forward layer processes a token in about 0.76 ms and uses roughly 0.19 MB of activations (peaking near 0.75 MB).

The sparse MoE layer, in contrast, keeps only about 0.04 MB of activations around (peaking at 0.11 MB). However, this comes at the cost of roughly twice the compute time.
(There is an added routing overhead, and my implementation may also not be the most efficient one.) + +Overall generation still peaks around 11.5 GB of GPU memory in both cases, since both versions load the same number of weight parameters and have the same KV cache size, which dominate here. + +Either way, we can see the trade-off here where MoE reduces the FFN memory by about 4-5× while roughly doubling the feed-forward compute time. + +Note that if we processed more tokens at one, e.g., with a batch size larger than 1 (here we don't have batches due to code simplicity), the savings would be more pronounced. + + + diff --git a/ch04/07_moe/ffn_moe_memory_estimator.py b/ch04/07_moe/ffn_moe_memory_estimator.py new file mode 100644 index 000000000..6fa7839d1 --- /dev/null +++ b/ch04/07_moe/ffn_moe_memory_estimator.py @@ -0,0 +1,127 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import argparse + +DTYPE_BYTES = { + "fp32": 4, + "bf16": 2, + "fp16": 2, + "fp8": 1, + "int8": 1, +} + + +def bytes_convert(n): + gb = n / (1000 ** 3) + return f"{gb:,.2f} GB" + + +def get_num_param_matrices(ffn_type): + if ffn_type == "gelu": + return 2 + elif ffn_type == "swiglu": + return 3 + else: + raise ValueError("--ffn_type must be 'gelu' or 'swiglu'") + + +def ffn_params(emb_dim, hidden_dim, ffn_type): + return get_num_param_matrices(ffn_type) * emb_dim * hidden_dim + + +def router_params(emb_dim, num_experts): + return emb_dim * num_experts + + +def estimate_params_and_hidden( + emb_dim, hidden_dim, ffn_type, num_experts, match_dense=False +): + P_dense = ffn_params(emb_dim, hidden_dim, ffn_type) + R = router_params(emb_dim, num_experts) + + if match_dense: + num_param_matrices = get_num_param_matrices(ffn_type) + num = P_dense - R + den = num_experts * num_param_matrices * emb_dim + if num <= 0: + raise ValueError("Dense layer too small for requested num_experts.") + moe_hidden_dim = int(round(num / float(den))) + else: + moe_hidden_dim = hidden_dim + + per_expert_params = ffn_params(emb_dim, moe_hidden_dim, ffn_type) + moe_total = num_experts * per_expert_params + R + + return { + "dense_params": P_dense, + "router": R, + "moe_hidden_dim": moe_hidden_dim, + "per_expert_params": per_expert_params, + "moe_total": moe_total, + } + + +def main(): + p = argparse.ArgumentParser( + description="Estimate FFN vs MoE parameter memory" + ) + p.add_argument("--emb_dim", type=int, required=True, + help="Model embedding dimension.") + p.add_argument("--hidden_dim", type=int, required=True, + help="Dense FFN intermediate size (hidden dimension).") + p.add_argument("--ffn_type", choices=["gelu", "swiglu"], default="swiglu") + p.add_argument("--num_experts", type=int, default=8) + p.add_argument("--top_k", type=int, default=2) + p.add_argument("--dtype", choices=DTYPE_BYTES.keys(), default="bf16") + p.add_argument( + "--match_dense", + action="store_true", + help=("Auto-set per-expert hidden so MoE total params ~= dense FFN params " + "(router included)."), + ) + args = p.parse_args() + + bytes_per_elem = DTYPE_BYTES[args.dtype] + + res = estimate_params_and_hidden( + emb_dim=args.emb_dim, + hidden_dim=args.hidden_dim, + ffn_type=args.ffn_type, + num_experts=args.num_experts, + match_dense=args.match_dense, + ) + + moe_active_params_per_token = ( + res["router"] + args.top_k * res["per_expert_params"] 
+ ) + + print("==== Config ====") + print(f"{'emb_dim':23}: {args.emb_dim}") + print(f"{'hidden_dim':23}: {args.hidden_dim}") + print(f"{'ffn_type':23}: {args.ffn_type}") + print(f"{'num_experts':23}: {args.num_experts}") + print(f"{'top_k':23}: {args.top_k}") + print(f"{'dtype':23}: {args.dtype} ({bytes_per_elem} Bytes/elem)") + print(f"{'match_dense':23}: {args.match_dense}") + print() + + print("==== Model weights (parameters) ====") + print(f"{'Dense FFN params':23}: {res['dense_params']:,} " + f"({bytes_convert(res['dense_params'] * bytes_per_elem)})") + print(f"{'Per-expert params':23}: {res['per_expert_params']:,} " + f"({bytes_convert(res['per_expert_params'] * bytes_per_elem)})") + print(f"{'Router params':23}: {res['router']:,} " + f"({bytes_convert(res['router'] * bytes_per_elem)})") + print(f"{'MoE TOTAL params':23}: {res['moe_total']:,} " + f"({bytes_convert(res['moe_total'] * bytes_per_elem)})") + print(f"{'MoE ACTIVE/Token':23}: {moe_active_params_per_token:,} " + f"({bytes_convert(moe_active_params_per_token * bytes_per_elem)})") + print(f"{'moe_hidden_dim':23}: {res['moe_hidden_dim']}") + print() + + +if __name__ == "__main__": + main() diff --git a/ch04/07_moe/gpt_with_kv_ffn.py b/ch04/07_moe/gpt_with_kv_ffn.py new file mode 100644 index 000000000..a6035e0d3 --- /dev/null +++ b/ch04/07_moe/gpt_with_kv_ffn.py @@ -0,0 +1,415 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# This file collects all the relevant code that we covered thus far +# throughout Chapters 3-4. +# This file can be run as a standalone script. 
+ +import argparse +import time +import tiktoken +import torch +import torch.nn as nn + +FFN_TIME_MS = [] +FFN_MEM_BYTES = [] + + +##################################### +# Chapter 3 +##################################### +class MultiHeadAttention(nn.Module): + def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + + #################################################### + # KV cache-related code + self.register_buffer("cache_k", None, persistent=False) + self.register_buffer("cache_v", None, persistent=False) + self.ptr_current_pos = 0 + #################################################### + + def forward(self, x, use_cache=False): + b, num_tokens, d_in = x.shape + + keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out) + values_new = self.W_value(x) + queries = self.W_query(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim) + values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + #################################################### + # KV cache-related + if use_cache: + if self.cache_k is None: + self.cache_k, self.cache_v = keys_new, values_new + else: + self.cache_k = torch.cat([self.cache_k, keys_new], dim=1) + self.cache_v = torch.cat([self.cache_v, values_new], dim=1) + keys, values = self.cache_k, self.cache_v + else: + keys, values = keys_new, values_new + #################################################### + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + + #################################################### + # causal mask + num_tokens_Q = queries.shape[-2] + num_tokens_K = keys.shape[-2] + device = queries.device + if use_cache: + q_positions = torch.arange( + self.ptr_current_pos, + self.ptr_current_pos + num_tokens_Q, + device=device, + dtype=torch.long, + ) + self.ptr_current_pos += num_tokens_Q + else: + q_positions = torch.arange(num_tokens_Q, device=device, dtype=torch.long) + self.ptr_current_pos = 0 + k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long) + mask_bool = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0) + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.contiguous().view(b, 
num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + def reset_cache(self): + self.cache_k, self.cache_v = None, None + self.ptr_current_pos = 0 + + +##################################### +# Chapter 4 +##################################### +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +# class FeedForward(nn.Module): +# def __init__(self, cfg): +# super().__init__() +# self.layers = nn.Sequential( +# nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), +# GELU(), +# nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), +# ) + +# def forward(self, x): +# return self.layers(x) + +# Uses SwiGLU instead of GeLU to make it more comparable to MoE +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], bias=False) + self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], bias=False) + self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], bias=False) + + def forward(self, x): + return self.fc3(torch.nn.functional.silu(self.fc1(x)) * self.fc2(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadAttention( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"], + ) + self.ff = FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x, use_cache=False): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + + # x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + #################################################### + # KV cache-related + x = self.att(x, use_cache=use_cache) + #################################################### + + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + use_cuda = torch.cuda.is_available() + if use_cuda: + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + base_mem = torch.cuda.memory_allocated() + start = time.perf_counter() + x = self.ff(x) + if use_cuda: + torch.cuda.synchronize() + peak_mem = torch.cuda.max_memory_allocated() + FFN_MEM_BYTES.append(peak_mem - base_mem) + FFN_TIME_MS.append((time.perf_counter() - start) * 1000.0) + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + return x + + +class GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + # self.trf_blocks = nn.Sequential( + # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + #################################################### + # KV cache-related + 
self.trf_blocks = nn.ModuleList( + [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + + self.current_pos = 0 + #################################################### + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx, use_cache=False): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + + # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + + #################################################### + # KV cache-related + if use_cache: + pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) + self.current_pos += seq_len + else: + pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long) + pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) + #################################################### + + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + + # x = self.trf_blocks(x) + #################################################### + # KV cache-related + for blk in self.trf_blocks: + x = blk(x, use_cache=use_cache) + #################################################### + + x = self.final_norm(x) + logits = self.out_head(x) + return logits + + #################################################### + # KV cache-related + def reset_kv_cache(self): + for blk in self.trf_blocks: + blk.att.reset_cache() + self.current_pos = 0 + #################################################### + + +def generate_text_simple_cached(model, idx, max_new_tokens, + context_size=None, use_cache=True): + model.eval() + ctx_len = context_size or model.pos_emb.num_embeddings + batch_size, base_len = idx.shape + total_len = base_len + max_new_tokens + generated = torch.empty( + batch_size, total_len, dtype=idx.dtype, device=idx.device + ) + generated[:, :base_len] = idx + cur_len = base_len + use_cuda = torch.cuda.is_available() + FFN_TIME_MS.clear() + FFN_MEM_BYTES.clear() + + with torch.no_grad(): + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + prompt_start = max(0, cur_len - ctx_len) + logits = model(generated[:, prompt_start:cur_len], use_cache=True) + + if use_cuda: + torch.cuda.synchronize() + + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1) + # b) append it to the running sequence (in-place) + generated[:, cur_len] = next_idx + cur_len += 1 + # c) feed model only the new token + logits = model(generated[:, cur_len - 1 : cur_len], use_cache=True) + + if use_cuda: + torch.cuda.synchronize() + else: + if use_cuda: + torch.cuda.synchronize() + + for _ in range(max_new_tokens): + start_ctx = max(0, cur_len - ctx_len) + logits = model(generated[:, start_ctx:cur_len], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1) + generated[:, cur_len] = next_idx + cur_len += 1 + + if use_cuda: + torch.cuda.synchronize() + + if FFN_TIME_MS: + avg_ffn_time = sum(FFN_TIME_MS) / len(FFN_TIME_MS) + print(f"Avg FFN time/call: {avg_ffn_time:.3f} ms") + if FFN_MEM_BYTES: + avg_ffn_mem = sum(FFN_MEM_BYTES) / len(FFN_MEM_BYTES) + max_ffn_mem = max(FFN_MEM_BYTES) + + def to_mb(bytes_val): + return bytes_val / (1024 ** 2) + print(f"Avg FFN mem delta/call: {to_mb(avg_ffn_mem):.2f} MB (max {to_mb(max_ffn_mem):.2f} MB)") + + return generated[:, :cur_len] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--emb_dim", type=int, 
default=768, help="Model embedding dimension.") + parser.add_argument("--hidden_dim", type=int, default=768*4, help="Intermediate FFN size.") + parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.") + parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.") + parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.") + parser.add_argument( + "--no_kv_cache", + action="store_true", + help="Disable KV caching during generation.", + ) + + args = parser.parse_args() + + start_context = "Hello, I am" + tokenizer = tiktoken.get_encoding("gpt2") + encoded = tokenizer.encode(start_context) + + GPT_CONFIG_124M = { + "vocab_size": 50257, # Vocabulary size + "context_length": args.max_new_tokens + len(encoded), + "emb_dim": args.emb_dim, # Embedding dimension + "hidden_dim": args.hidden_dim, # Intermediate size + "n_heads": args.n_heads, # Number of attention heads + "n_layers": args.n_layers, # Number of layers + "drop_rate": 0.0, # Dropout rate + "qkv_bias": False, # Query-Key-Value bias + } + torch.manual_seed(123) + model = GPTModel(GPT_CONFIG_124M) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device, dtype=torch.bfloat16) + model.eval() # disable dropout + + encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0) + print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}") + print("\nInput text:", start_context) + print("Encoded input text:", encoded) + print("encoded_tensor.shape:", encoded_tensor.shape) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = time.time() + + token_ids = generate_text_simple_cached( + model=model, + idx=encoded_tensor, + max_new_tokens=args.max_new_tokens, + use_cache=not args.no_kv_cache, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_time = time.time() - start + + decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist()) + + print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}") + print("\nOutput:", token_ids) + print("Output length:", len(token_ids[0])) + print("Output text:", decoded_text) + + print(f"\nTime: {total_time:.2f} sec") + print(f"{int(len(token_ids[0])/total_time)} tokens/sec") + if torch.cuda.is_available(): + max_mem_bytes = torch.cuda.max_memory_allocated() + max_mem_gb = max_mem_bytes / (1024 ** 3) + print(f"Max memory allocated: {max_mem_gb:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/ch04/07_moe/gpt_with_kv_moe.py b/ch04/07_moe/gpt_with_kv_moe.py new file mode 100644 index 000000000..498c98747 --- /dev/null +++ b/ch04/07_moe/gpt_with_kv_moe.py @@ -0,0 +1,490 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# This file collects all the relevant code that we covered thus far +# throughout Chapters 3-4. +# This file can be run as a standalone script. 
+ +import argparse +import time +import tiktoken +import torch +import torch.nn as nn + +MOE_FF_TIME_MS = [] +MOE_FF_MEM_BYTES = [] + + +##################################### +# Chapter 3 +##################################### +class MultiHeadAttention(nn.Module): + def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + + #################################################### + # KV cache-related code + self.register_buffer("cache_k", None, persistent=False) + self.register_buffer("cache_v", None, persistent=False) + self.ptr_current_pos = 0 + #################################################### + + def forward(self, x, use_cache=False): + b, num_tokens, d_in = x.shape + + keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out) + values_new = self.W_value(x) + queries = self.W_query(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim) + values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + #################################################### + # KV cache-related + if use_cache: + if self.cache_k is None: + self.cache_k, self.cache_v = keys_new, values_new + else: + self.cache_k = torch.cat([self.cache_k, keys_new], dim=1) + self.cache_v = torch.cat([self.cache_v, values_new], dim=1) + keys, values = self.cache_k, self.cache_v + else: + keys, values = keys_new, values_new + #################################################### + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + + #################################################### + # causal mask + num_tokens_Q = queries.shape[-2] + num_tokens_K = keys.shape[-2] + device = queries.device + if use_cache: + q_positions = torch.arange( + self.ptr_current_pos, + self.ptr_current_pos + num_tokens_Q, + device=device, + dtype=torch.long, + ) + self.ptr_current_pos += num_tokens_Q + else: + q_positions = torch.arange(num_tokens_Q, device=device, dtype=torch.long) + self.ptr_current_pos = 0 + k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long) + mask_bool = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0) + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = 
context_vec.contiguous().view(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + def reset_cache(self): + self.cache_k, self.cache_v = None, None + self.ptr_current_pos = 0 + + +##################################### +# Chapter 4 +##################################### +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(cfg["emb_dim"], cfg["hidden_dim"]), + GELU(), + nn.Linear(cfg["hidden_dim"], cfg["emb_dim"]), + ) + + def forward(self, x): + return self.layers(x) + + +class MoEFeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.num_experts_per_tok = cfg["num_experts_per_tok"] + self.num_experts = cfg["num_experts"] + self.emb_dim = cfg["emb_dim"] + + self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False) + self.fc1 = nn.ModuleList( + [ + nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], bias=False) + for _ in range(self.num_experts) + ] + ) + self.fc2 = nn.ModuleList( + [ + nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], bias=False) + for _ in range(self.num_experts) + ] + ) + self.fc3 = nn.ModuleList( + [ + nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], bias=False) + for _ in range(self.num_experts) + ] + ) + + def forward(self, x): + # x: (batch, seq_len, emb_dim) + scores = self.gate(x) # (b, seq_len, num_experts) + topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1) + topk_probs = torch.softmax(topk_scores, dim=-1) + + batch, seq_len, _ = x.shape + x_flat = x.reshape(batch * seq_len, -1) + out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype) + + topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok) + topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok) + + unique_experts = torch.unique(topk_indices_flat) + + for expert_id_tensor in unique_experts: + expert_id = int(expert_id_tensor.item()) + + mask = topk_indices_flat == expert_id + if not mask.any(): + continue + + token_mask = mask.any(dim=-1) + selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1) + if selected_idx.numel() == 0: + continue + + expert_input = x_flat.index_select(0, selected_idx) + hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[ + expert_id + ](expert_input) + expert_out = self.fc3[expert_id](hidden) + + mask_selected = mask[selected_idx] + slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True) + selected_probs = torch.gather( + topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices + ).squeeze(-1) + + out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1)) + + return out_flat.reshape(batch, seq_len, self.emb_dim) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadAttention( + 
d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"], + ) + self.ff = MoEFeedForward(cfg) if cfg["num_experts"] > 0 else FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x, use_cache=False): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + + # x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + #################################################### + # KV cache-related + x = self.att(x, use_cache=use_cache) + #################################################### + + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + use_cuda = torch.cuda.is_available() + if use_cuda: + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + base_mem = torch.cuda.memory_allocated() + start = time.perf_counter() + x = self.ff(x) + if use_cuda: + torch.cuda.synchronize() + peak_mem = torch.cuda.max_memory_allocated() + MOE_FF_MEM_BYTES.append(peak_mem - base_mem) + MOE_FF_TIME_MS.append((time.perf_counter() - start) * 1000.0) + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + return x + + +class GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + # self.trf_blocks = nn.Sequential( + # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + #################################################### + # KV cache-related + self.trf_blocks = nn.ModuleList( + [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + + self.current_pos = 0 + #################################################### + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx, use_cache=False): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + + # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + + #################################################### + # KV cache-related + if use_cache: + pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) + self.current_pos += seq_len + else: + pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long) + pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) + #################################################### + + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + + # x = self.trf_blocks(x) + #################################################### + # KV cache-related + for blk in self.trf_blocks: + x = blk(x, use_cache=use_cache) + #################################################### + + x = self.final_norm(x) + logits = self.out_head(x) + return logits + + #################################################### + # KV cache-related + def reset_kv_cache(self): + for blk in self.trf_blocks: + blk.att.reset_cache() + self.current_pos = 0 + #################################################### + + +def generate_text_simple_cached(model, idx, max_new_tokens, + context_size=None, use_cache=True): + model.eval() + ctx_len = context_size or model.pos_emb.num_embeddings + batch_size, base_len 
= idx.shape + total_len = base_len + max_new_tokens + generated = torch.empty( + batch_size, total_len, dtype=idx.dtype, device=idx.device + ) + generated[:, :base_len] = idx + cur_len = base_len + use_cuda = torch.cuda.is_available() + MOE_FF_TIME_MS.clear() + MOE_FF_MEM_BYTES.clear() + + with torch.no_grad(): + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + prompt_start = max(0, cur_len - ctx_len) + logits = model(generated[:, prompt_start:cur_len], use_cache=True) + + if use_cuda: + torch.cuda.synchronize() + + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1) + # b) append it to the running sequence (in-place) + generated[:, cur_len] = next_idx + cur_len += 1 + # c) feed model only the new token + logits = model(generated[:, cur_len - 1 : cur_len], use_cache=True) + + if use_cuda: + torch.cuda.synchronize() + else: + if use_cuda: + torch.cuda.synchronize() + + for _ in range(max_new_tokens): + start_ctx = max(0, cur_len - ctx_len) + logits = model(generated[:, start_ctx:cur_len], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1) + generated[:, cur_len] = next_idx + cur_len += 1 + + if use_cuda: + torch.cuda.synchronize() + + if MOE_FF_TIME_MS: + avg_ffn_time = sum(MOE_FF_TIME_MS) / len(MOE_FF_TIME_MS) + print(f"Avg MoE FF time/call: {avg_ffn_time:.3f} ms") + if MOE_FF_MEM_BYTES: + avg_ffn_mem = sum(MOE_FF_MEM_BYTES) / len(MOE_FF_MEM_BYTES) + max_ffn_mem = max(MOE_FF_MEM_BYTES) + + def to_mb(bytes_val): + return bytes_val / (1024 ** 2) + print(f"Avg MoE FF mem delta/call: {to_mb(avg_ffn_mem):.2f} MB (max {to_mb(max_ffn_mem):.2f} MB)") + + return generated[:, :cur_len] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.") + parser.add_argument("--hidden_dim", type=int, default=768*4, help="Intermediate FFN or MoE size.") + parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.") + parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.") + parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.") + parser.add_argument( + "--no_kv_cache", + action="store_true", + help="Disable KV caching during generation.", + ) + + parser.add_argument( + "--num_experts", + type=int, + default=0, + help="Number of experts. If 0, use dense FFN. 
If >0, use MoE.", + ) + parser.add_argument( + "--num_experts_per_tok", + type=int, + default=2, + help="Top-k experts per token when using MoE (ignored if num_experts=0).", + ) + + args = parser.parse_args() + + start_context = "Hello, I am" + tokenizer = tiktoken.get_encoding("gpt2") + encoded = tokenizer.encode(start_context) + + GPT_CONFIG_124M = { + "vocab_size": 50257, # Vocabulary size + "context_length": args.max_new_tokens + len(encoded), + "emb_dim": args.emb_dim, # Embedding dimension + "hidden_dim": args.hidden_dim, # Intermediate size + "n_heads": args.n_heads, # Number of attention heads + "n_layers": args.n_layers, # Number of layers + "drop_rate": 0.0, # Dropout rate + "qkv_bias": False, # Query-Key-Value bias + "num_experts": args.num_experts, + "num_experts_per_tok": args.num_experts_per_tok if args.num_experts > 0 else 0, + } + torch.manual_seed(123) + model = GPTModel(GPT_CONFIG_124M) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device, dtype=torch.bfloat16) + model.eval() # disable dropout + + encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0) + print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}") + print("\nInput text:", start_context) + print("Encoded input text:", encoded) + print("encoded_tensor.shape:", encoded_tensor.shape) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = time.time() + + token_ids = generate_text_simple_cached( + model=model, + idx=encoded_tensor, + max_new_tokens=args.max_new_tokens, + use_cache=not args.no_kv_cache, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_time = time.time() - start + + decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist()) + + print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}") + print("\nOutput:", token_ids) + print("Output length:", len(token_ids[0])) + print("Output text:", decoded_text) + + print(f"\nTime: {total_time:.2f} sec") + print(f"{int(len(token_ids[0])/total_time)} tokens/sec") + if torch.cuda.is_available(): + max_mem_bytes = torch.cuda.max_memory_allocated() + max_mem_gb = max_mem_bytes / (1024 ** 3) + print(f"Max memory allocated: {max_mem_gb:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/ch04/07_moe/plot_memory_estimates_moe.py b/ch04/07_moe/plot_memory_estimates_moe.py new file mode 100644 index 000000000..089e1c530 --- /dev/null +++ b/ch04/07_moe/plot_memory_estimates_moe.py @@ -0,0 +1,123 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). 
+# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + + +import argparse +import matplotlib.pyplot as plt +from ffn_moe_memory_estimator import ( + estimate_params_and_hidden, + ffn_params, + router_params, +) + + +def moe_active_and_total( + emb_dim, + hidden_dim, + ffn_type, + num_experts, + top_k, + match_dense=True, +): + if match_dense: + dense_params = ffn_params(emb_dim, hidden_dim, ffn_type) + router = router_params(emb_dim, num_experts) + if dense_params <= router: + match_dense = False + + stats = estimate_params_and_hidden( + emb_dim=emb_dim, + hidden_dim=hidden_dim, + ffn_type=ffn_type, + num_experts=num_experts, + match_dense=match_dense, + ) + + active = stats["router"] + top_k * stats["per_expert_params"] + return active, stats["moe_total"] + + +def plot_active_params_vs_experts( + emb_dim, + hidden_dim, + ffn_type="swiglu", + top_k=2, + max_experts=512, + y_log=True, + save_path=None, + match_dense=True, +): + experts = [1, 2, 4, 8, 16, 32, 64, 128, 192, 256, 384, 512] + experts = [e for e in experts if e <= max_experts] + + dense_active = ffn_params(emb_dim, hidden_dim, ffn_type) + moe_active = [] + moe_total = [] + for e in experts: + active, total = moe_active_and_total( + emb_dim=emb_dim, + hidden_dim=hidden_dim, + ffn_type=ffn_type, + num_experts=e, + top_k=top_k, + match_dense=match_dense, + ) + moe_active.append(active) + moe_total.append(total) + + plt.figure(figsize=(7, 5)) + plt.plot(experts, moe_active, marker="o", label="MoE active per token") + plt.plot(experts, moe_total, marker="s", linestyle="--", label="MoE total parameters") + plt.axhline(dense_active, linestyle=":", color="gray", + label="FFN dense (active = total)") + + plt.xlabel(f"Number of experts (top_k = {top_k})") + plt.ylabel("Parameters") + if y_log: + plt.yscale("log") + plt.title( + f"Active vs Total Parameters per Token\n" + f"(emb_dim={emb_dim}, hidden_dim={hidden_dim}, ffn={ffn_type}, top_k={top_k})" + ) + plt.legend() + plt.tight_layout() + if save_path: + plt.savefig(save_path, dpi=200) + print(f"Saved plot to {save_path}") + else: + plt.show() + + +def main(): + p = argparse.ArgumentParser(description="Plot Dense vs MoE active parameters.") + p.add_argument("--emb_dim", type=int, required=True, help="Embedding dimension") + p.add_argument("--hidden_dim", type=int, required=True, help="Dense FFN hidden size") + p.add_argument("--ffn_type", choices=["gelu", "swiglu"], default="swiglu") + p.add_argument("--top_k", type=int, default=2, help="Active experts per token") + p.add_argument("--max_experts", type=int, default=512, help="Max experts on x-axis") + p.add_argument("--no_log", action="store_true", help="Disable log-scale y-axis") + p.add_argument("--save", type=str, default=None, help="Optional path to save PNG") + p.add_argument( + "--no_match_dense", + action="store_true", + help=("Disable matching MoE parameters to dense FFN total; " + "uses provided hidden_dim instead."), + ) + args = p.parse_args() + + plot_active_params_vs_experts( + emb_dim=args.emb_dim, + hidden_dim=args.hidden_dim, + ffn_type=args.ffn_type, + top_k=args.top_k, + max_experts=args.max_experts, + y_log=not args.no_log, + save_path=args.save, + match_dense=not args.no_match_dense, + ) + + +if __name__ == "__main__": + main() diff --git a/ch04/README.md b/ch04/README.md index ea5b551d4..e5daac6d1 100644 --- a/ch04/README.md +++ b/ch04/README.md @@ -10,6 +10,7 @@ - 
[02_performance-analysis](02_performance-analysis) contains optional code analyzing the performance of the GPT model(s) implemented in the main chapter - [03_kv-cache](03_kv-cache) implements a KV cache to speed up the text generation during inference +- [07_moe](07_moe) explanation and implementation of Mixture-of-Experts (MoE) - [ch05/07_gpt_to_llama](../ch05/07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI (it might be interesting to look at alternative architectures after completing chapter 4, but you can also save that for after reading chapter 5) diff --git a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb index 396397061..56e29e7fa 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb @@ -153,11 +153,11 @@ " self.emb_dim = cfg[\"emb_dim\"]\n", " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", " for _ in range(cfg[\"num_experts\"])])\n", - " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", " for _ in range(cfg[\"num_experts\"])])\n", - " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_hidden_dim\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", " for _ in range(cfg[\"num_experts\"])])\n", "\n", " def forward(self, x):\n", @@ -550,7 +550,7 @@ " \"dtype\": torch.bfloat16,\n", " \"num_experts\": 128,\n", " \"num_experts_per_tok\": 8,\n", - " \"moe_intermediate_size\": 768,\n", + " \"moe_hidden_dim\": 768,\n", "}" ] }, @@ -1223,7 +1223,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/ch05/11_qwen3/standalone-qwen3-moe.ipynb b/ch05/11_qwen3/standalone-qwen3-moe.ipynb index bb2de06f9..fd187348e 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe.ipynb @@ -136,7 +136,7 @@ " super().__init__()\n", " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", - " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dt∂ype=cfg[\"dtype\"], bias=False)\n", "\n", " def forward(self, x):\n", " x_fc1 = self.fc1(x)\n", @@ -153,11 +153,11 @@ " self.emb_dim = cfg[\"emb_dim\"]\n", " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", " for _ in 
range(cfg[\"num_experts\"])])\n", - " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", " for _ in range(cfg[\"num_experts\"])])\n", - " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_hidden_dim\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", " for _ in range(cfg[\"num_experts\"])])\n", "\n", " def forward(self, x):\n", @@ -492,7 +492,7 @@ " \"dtype\": torch.bfloat16,\n", " \"num_experts\": 128,\n", " \"num_experts_per_tok\": 8,\n", - " \"moe_intermediate_size\": 768,\n", + " \"moe_hidden_dim\": 768,\n", "}" ] }, @@ -1140,7 +1140,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.13.5" } }, "nbformat": 4, From 3184df09f188631644d3e1cbc7637a5c91236b84 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Sun, 19 Oct 2025 21:43:54 -0500 Subject: [PATCH 2/4] Apply suggestion from @rasbt --- ch05/11_qwen3/standalone-qwen3-moe.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ch05/11_qwen3/standalone-qwen3-moe.ipynb b/ch05/11_qwen3/standalone-qwen3-moe.ipynb index fd187348e..e9a8f22f3 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe.ipynb @@ -136,7 +136,7 @@ " super().__init__()\n", " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", - " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dt∂ype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", "\n", " def forward(self, x):\n", " x_fc1 = self.fc1(x)\n", From 4ebeb84e1edfa15ab3e73bcd7a07de4592614ebf Mon Sep 17 00:00:00 2001 From: rasbt Date: Sun, 19 Oct 2025 21:51:04 -0500 Subject: [PATCH 3/4] Mixture-of-Experts intro --- ch04/07_moe/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ch04/07_moe/README.md b/ch04/07_moe/README.md index 440adb2f8..6702dedc4 100644 --- a/ch04/07_moe/README.md +++ b/ch04/07_moe/README.md @@ -100,7 +100,7 @@ uv run plot_memory_estimates_moe.py \ The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_moe.py](gpt_with_kv_moe.py) scripts in this folder provide hands-on examples for comparing the regular FFN and MoE memory usage in the context of a GPT model implementation. Note that both scripts use [SwiGLU](https://arxiv.org/abs/2002.05202) feed-forward modules as shown in the first figure of this page (GPT-2 traditionally uses GELU). -**Note: The model is not trained and thus generates nonsensical text. You can find a trained MoE in the bonus materials at [../ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb](../ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb).** +**Note: The model is not trained and thus generates nonsensical text. 
You can find a trained MoE in the bonus materials at [../../ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb](../../ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb).**

From f2870cb10d7e0aefe96a5c7210ef2c Mon Sep 17 00:00:00 2001
From: rasbt
Date: Sun, 19 Oct 2025 22:02:03 -0500
Subject: [PATCH 4/4] update

---
 ch04/07_moe/README.md                                        | 4 ++--
 .../{ffn_moe_memory_estimator.py => memory_estimator_moe.py} | 0
 2 files changed, 2 insertions(+), 2 deletions(-)
 rename ch04/07_moe/{ffn_moe_memory_estimator.py => memory_estimator_moe.py} (100%)

diff --git a/ch04/07_moe/README.md b/ch04/07_moe/README.md
index 6702dedc4..73d8e9e87 100644
--- a/ch04/07_moe/README.md
+++ b/ch04/07_moe/README.md
@@ -48,7 +48,7 @@ When using an MoE layer, only `top_k` experts are active per token, so the effec
 You can use the [memory_estimator_moe.py](memory_estimator_moe.py) script in this folder to apply this for different model configs to see how much memory you can save by using MoE over FFN (note that this is for a single transformer block, to get the total savings, multiply by the number of transformer blocks in your model):
 
 ```bash
-uv run ffn_moe_memory_estimator.py --emb_dim 7168 --hidden_dim 14336 --ffn_type swiglu \
+uv run memory_estimator_moe.py --emb_dim 7168 --hidden_dim 14336 --ffn_type swiglu \
 --num_experts 8 --top_k 2 --match_dense
 ==== Config ====
 emb_dim          : 7168
@@ -98,7 +98,7 @@ uv run plot_memory_estimates_moe.py \
   
 ## MoE Code Examples
 
-The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_moe.py](gpt_with_kv_moe.py) scripts in this folder provide hands-on examples for comparing the regular FFN and MoE memory usage in the context of a GPT model implementation. Note that both scripts use [SwiGLU](https://arxiv.org/abs/2002.05202) feed-forward modules as shown in the first figure of this page (GPT-2 traditionally uses GELU).
+The [gpt_with_kv_ffn.py](gpt_with_kv_ffn.py) and [gpt_with_kv_moe.py](gpt_with_kv_moe.py) scripts in this folder provide hands-on examples for comparing the regular FFN and MoE memory usage in the context of a GPT model implementation. Note that both scripts use [SwiGLU](https://arxiv.org/abs/2002.05202) feed-forward modules as shown in the first figure of this page (GPT-2 traditionally uses GELU).
 
 **Note: The model is not trained and thus generates nonsensical text. You can find a trained MoE in the bonus materials at [../../ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb](../../ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb).**
 
diff --git a/ch04/07_moe/ffn_moe_memory_estimator.py b/ch04/07_moe/memory_estimator_moe.py
similarity index 100%
rename from ch04/07_moe/ffn_moe_memory_estimator.py
rename to ch04/07_moe/memory_estimator_moe.py
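
For quick sanity-checking during review, the active-vs-total parameter arithmetic that `memory_estimator_moe.py` and `plot_memory_estimates_moe.py` report can be reproduced in a few lines. The sketch below is a minimal approximation based on the bias-free SwiGLU experts and router defined in the notebook diffs above; the helper names (`swiglu_ffn_params`, `moe_params`) are illustrative rather than the estimators' actual API, and using `hidden_dim // num_experts` as the expert hidden size is only a rough stand-in for what `--match_dense` computes.

```python
# Illustrative sketch (not the estimator's code): parameter counts for a dense
# SwiGLU FFN vs. an MoE layer with bias-free expert and router weights.

def swiglu_ffn_params(emb_dim: int, hidden_dim: int) -> int:
    # fc1: emb_dim -> hidden_dim, fc2: emb_dim -> hidden_dim, fc3: hidden_dim -> emb_dim
    return 3 * emb_dim * hidden_dim

def moe_params(emb_dim: int, moe_hidden_dim: int, num_experts: int, top_k: int):
    router = emb_dim * num_experts  # gate: Linear(emb_dim, num_experts, bias=False)
    per_expert = swiglu_ffn_params(emb_dim, moe_hidden_dim)
    total = router + num_experts * per_expert   # parameters stored in memory
    active = router + top_k * per_expert        # parameters used per token
    return active, total

if __name__ == "__main__":
    # Same emb_dim/hidden_dim as the README example (--num_experts 8 --top_k 2);
    # the expert hidden size is chosen so the experts' combined size roughly
    # matches the dense FFN, similar in spirit to --match_dense.
    emb_dim, hidden_dim, num_experts, top_k = 7168, 14336, 8, 2
    dense = swiglu_ffn_params(emb_dim, hidden_dim)
    active, total = moe_params(emb_dim, hidden_dim // num_experts, num_experts, top_k)
    print(f"Dense FFN parameters : {dense:,}")
    print(f"MoE total parameters : {total:,}")
    print(f"MoE active per token : {active:,}")
```

With 8 experts and `top_k = 2`, only about a quarter of the (roughly dense-sized) MoE parameters are active for any given token; this widening gap between active and total parameters as the expert count grows is what the `plot_memory_estimates_moe.py` figure visualizes.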