76 changes: 76 additions & 0 deletions examples/models/gpt_oss/README.md
@@ -0,0 +1,76 @@
# gpt-oss on Trainium

A clean implementation of OpenAI's gpt-oss MoE models (e.g., `gpt-oss-20b`) for
AWS Trainium, built on NKIPy.

## Setup

``` sh
cd nkipy
uv sync --all-groups
source .venv/bin/activate
cd examples/models/gpt_oss
```

## Quickstart

`test.sh` handles weight preparation and runs a generation end-to-end:

``` sh
./test.sh
```

Or run generation directly (assumes weights are already prepared):

``` sh
WEIGHTS=./tmp_gpt-oss-20b
TP=4

torchrun --nproc-per-node $TP gpt_oss.py \
-n 500 --checkpoint $WEIGHTS --model openai/gpt-oss-20b \
"The capital of France is"
```

You can point `--model` at a local checkpoint directory too.

## Weight preparation

gpt-oss ships its experts **MXFP4-quantized** (`*_blocks` / `*_scales`). The prep
step dequantizes them to bf16 so the NKI kernels run purely in bf16, and it
shards every tensor for tensor parallelism:

``` sh
python tensor_preparation.py \
--model-name openai/gpt-oss-20b \
--world-size 4 \
--output-dir ./tmp_gpt-oss-20b
```

This writes `shard_{rank}.safetensors` files. The dequantized bf16 weights
(~40 GB total) are larger than the packed MXFP4 checkpoint, so make sure you
have enough free disk space.
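
As a rough illustration of what the dequantization step involves, the sketch below decodes MXFP4 in NumPy. It assumes the OCP MX layout: 32-element blocks, one E8M0 scale byte per block (biased by 127), and two E2M1 codes per byte with the low nibble first. The function name and tensor shapes are hypothetical and are not taken from `tensor_preparation.py`. NumPy has no bf16 type, so the sketch returns float32.

```python
import numpy as np

# E2M1 (FP4) code -> value table: 1 sign bit, 2 exponent bits, 1 mantissa bit.
FP4_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32,
)


def dequantize_mxfp4(blocks: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Decode packed MXFP4 to float32 (hypothetical helper).

    blocks: uint8, shape (..., n_blocks, 16); each byte packs two FP4 codes.
    scales: uint8, shape (..., n_blocks); E8M0 exponent biased by 127.
    Returns float32 of shape (..., n_blocks * 32).
    """
    lo = FP4_VALUES[blocks & 0x0F]  # assumed: low nibble holds the even element
    hi = FP4_VALUES[blocks >> 4]    # assumed: high nibble holds the odd element
    # Interleave lo/hi so each block expands from 16 bytes to 32 values.
    vals = np.stack([lo, hi], axis=-1).reshape(*blocks.shape[:-1], -1)
    exp = scales.astype(np.int32) - 127
    out = np.ldexp(vals, exp[..., None])  # vals * 2**exp, one exponent per block
    return out.reshape(*blocks.shape[:-2], -1)


# One block: every byte is 0x21 (low code 1 -> 0.5, high code 2 -> 1.0), scale 2**1.
blocks = np.full((1, 16), 0x21, dtype=np.uint8)
scales = np.array([128], dtype=np.uint8)
weights = dequantize_mxfp4(blocks, scales)
```

A real prep step would additionally cast the result to bf16 and shard it per rank.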

## Architecture notes

gpt-oss differs from the Qwen3 MoE example in several ways, all handled here:

| Feature | Handling |
|---|---|
| MXFP4 experts | Dequantized to bf16 at prep time (`tensor_preparation.py`) |
| Interleaved gate/up | De-interleaved to `[gate \| up]` at prep time |
| Clamped SwiGLU | `(up+1) * gate*sigmoid(alpha*gate)` with `clamp(limit=7)` (`kernels/feedforward.py`) |
| Attention sinks | Per-head sink logit concatenated into softmax, then dropped (`kernels/attention.py`) |
| QKV / output bias | Carried through prep and added in the attention kernel |
| Sliding-window attention | Alternating sliding (window=128) / full causal layers; one kernel compiled per attention type |
| YaRN RoPE | `inv_freq` + attention-scaling precomputed from HF config (`config.py`) and baked into the cos/sin cache |
| Router | top-k on raw logits (+bias), then softmax over the selected logits |

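The router and clamped-SwiGLU rows can be sketched in NumPy as follows. This is a minimal illustration, not the NKI kernel: the function names are hypothetical, and the exact clamping (gate clamped from above only, up clamped on both sides) is an assumption based on the Hugging Face reference implementation of gpt-oss.

```python
import numpy as np


def route_top_k(logits: np.ndarray, top_k: int):
    """Pick top-k experts on raw router logits, then softmax over only those."""
    idx = np.argsort(-logits, axis=-1)[..., :top_k]   # expert ids, best first
    sel = np.take_along_axis(logits, idx, axis=-1)    # selected raw logits
    sel = sel - sel.max(axis=-1, keepdims=True)       # numerical stability
    weights = np.exp(sel)
    weights /= weights.sum(axis=-1, keepdims=True)
    return idx, weights


def deinterleave_gate_up(w: np.ndarray) -> np.ndarray:
    """Reorder [g0, u0, g1, u1, ...] along the last axis into [gate | up]."""
    return np.concatenate([w[..., 0::2], w[..., 1::2]], axis=-1)


def clamped_swiglu(gate_up: np.ndarray, alpha: float = 1.702, limit: float = 7.0):
    """(up + 1) * gate * sigmoid(alpha * gate) on a [gate | up] input."""
    gate, up = np.split(gate_up, 2, axis=-1)
    gate = np.minimum(gate, limit)              # assumed: upper clamp only
    up = np.clip(up, -limit, limit)             # assumed: symmetric clamp
    glu = gate / (1.0 + np.exp(-alpha * gate))  # gate * sigmoid(alpha * gate)
    return (up + 1.0) * glu


idx, weights = route_top_k(np.array([[0.0, 2.0, 1.0, -1.0]]), top_k=2)
reordered = deinterleave_gate_up(np.array([0.0, 1.0, 2.0, 3.0]))
saturated = clamped_swiglu(np.array([[100.0, 100.0]]))
```

Because the softmax runs only over the selected logits, the routing weights of each token sum to 1 regardless of how many experts exist.
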
## Files

| File | Purpose |
|---|---|
| `gpt_oss.py` | Model definition (`GptOssModel`) and text generation |
| `config.py` | Model configuration (incl. YaRN RoPE precompute) |
| `tensor_preparation.py` | Dequantize, reshape, and shard HF weights for TP |
| `test.sh` | Smoke test: prepares weights and runs generation |
| `kernels/` | Attention, feed-forward, RoPE, RMSNorm, softmax, sampling kernels |
89 changes: 89 additions & 0 deletions examples/models/gpt_oss/config.py
@@ -0,0 +1,89 @@
from dataclasses import dataclass

import numpy as np
import torch.distributed as dist
from neuronxcc.nki.language import bfloat16
from transformers import AutoConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# to control compiler_args
DTYPE = bfloat16


@dataclass
class Config:
    hidden_size: int
    num_heads: int
    head_dim: int
    num_kv_heads: int
    num_layers: int
    num_experts_per_tok: int
    num_experts: int
    # RoPE (YaRN) inverse frequencies and post-scaling, precomputed from HF.
    rope_inv_freq: np.ndarray
    rope_attention_scaling: float
    # Per-layer attention type: "sliding_attention" or "full_attention".
    layer_types: list
    sliding_window: int
    # Clamped-SwiGLU parameters (gpt-oss specific).
    swiglu_alpha: float = 1.702
    swiglu_limit: float = 7.0
    context_len: int = None
    max_new_tokens: int = None
    max_batch_size: int = 1
    norm_eps: float = 1e-5
    intermediate_size: int = 2880
    max_seq_len: int = 4096
    dtype: np.dtype = DTYPE
    additional_compiler_args_nkipy: str = "--lnc 1"
    # Decoder-layer indices whose outputs are tapped for EAGLE-3 speculative
    # decoding. None disables capture (the default, non-speculative path).
    aux_layers: tuple = None

    def is_sliding(self, layer_id: int) -> bool:
        return self.layer_types[layer_id] == "sliding_attention"

    @staticmethod
    def default_aux_layers(num_layers: int) -> tuple:
        """EAGLE-3's low/mid/high taps as the *outputs* of these decoder layers.

        vLLM's EAGLE-3 default is ``(2, n//2, n-3)`` but it captures the residual
        stream *entering* those layers (``_maybe_add_hidden_state`` runs with
        ``layer_idx=i+1`` after layer ``i``). Our prefill loop captures *after*
        running layer ``i``, so to tap the same values we shift down by one:
        the input of layer ``L`` is the output of layer ``L-1``. Verified on GPU
        against vLLM for gpt-oss-20b: the drafter's 3 fc chunks equal the outputs
        of target layers (1, 11, 20) at cosine 1.0 (see eagle README).
        """
        return (2 - 1, num_layers // 2 - 1, num_layers - 3 - 1)


def get_config(model_name, context_len, max_new_tokens):
    hf_config = AutoConfig.from_pretrained(model_name)

    # YaRN RoPE: precompute inverse frequencies + attention scaling factor once.
    # These are constants (independent of runtime tensors), so we bake them into
    # the kernel's cos/sin cache at compile time.
    rope_init_fn = ROPE_INIT_FUNCTIONS[hf_config.rope_parameters["rope_type"]]
    inv_freq, attention_scaling = rope_init_fn(hf_config, device=None)

    config = Config(
        hidden_size=hf_config.hidden_size,
        intermediate_size=hf_config.intermediate_size // dist.get_world_size(),
        num_heads=hf_config.num_attention_heads,
        head_dim=hf_config.head_dim,
        num_kv_heads=hf_config.num_key_value_heads,
        norm_eps=hf_config.rms_norm_eps,
        num_layers=hf_config.num_hidden_layers,
        num_experts_per_tok=hf_config.num_experts_per_tok,
        num_experts=hf_config.num_local_experts,
        rope_inv_freq=np.asarray(inv_freq, dtype=np.float32),
        rope_attention_scaling=float(attention_scaling),
        layer_types=list(hf_config.layer_types),
        sliding_window=hf_config.sliding_window,
        swiglu_alpha=getattr(hf_config, "swiglu_alpha", 1.702),
        swiglu_limit=hf_config.swiglu_limit,
        context_len=context_len,
        max_new_tokens=max_new_tokens,
    )
    return config
206 changes: 206 additions & 0 deletions examples/models/gpt_oss/eagle/README.md
@@ -0,0 +1,206 @@
# P-EAGLE Speculative Decoding for gpt-oss on Trainium

Parallel-drafting speculative decoding using [P-EAGLE](https://arxiv.org/abs/2602.01469)
for the gpt-oss model family on AWS Trainium. Generates K draft tokens in a
**single forward pass** (not K sequential passes), then verifies them against the
target in one multi-token target forward.

## Setup

``` sh
cd nkipy
uv sync --all-groups
source .venv/bin/activate
cd examples/models/gpt_oss
```

## Quickstart

### 1. Prepare weights

The target model (gpt-oss-20b) must already be prepared (see `../README.md`):

``` sh
# Target (if not already done)
python tensor_preparation.py \
--model-name /path/to/gpt-oss-20b \
--world-size 4 --output-dir ./tmp_gpt-oss-20b

# Drafter (P-EAGLE, replicated on every rank — small, ~3.6 GB)
python eagle/tensor_preparation.py \
--model-name /path/to/GPT-OSS-20B-P-EAGLE \
--output-dir ./eagle/tmp_p-eagle
```

### 2. Run speculative decoding

``` sh
TP=4
torchrun --nproc-per-node $TP eagle/speculate.py \
--target-checkpoint ./tmp_gpt-oss-20b \
--draft-checkpoint ./eagle/tmp_p-eagle \
--model /path/to/gpt-oss-20b \
--draft-model /path/to/GPT-OSS-20B-P-EAGLE \
-n 256 -k 7 \
"Write a Python function that implements binary search."
```

Output includes acceptance metrics:

```
Time to first token: 0.6s
Generated 256 tokens in N verify steps
Mean acceptance length: X.XX (K=7)
Decode tokens/sec: XX.XX
```

## How it works

### Speculation loop

```
1. Target prefill on prompt → first token + 3 tapped hidden states
2. Drafter prefill on prompt (EAGLE shift) → drafter KV cache over positions 0..P-2
3. Loop:
a. Drafter: roll cache back to last accepted pos; run
[newly-accepted tokens | K-1 ptd slots] in ONE parallel forward pass,
attending to the FULL drafter KV cache → K draft tokens
b. Target verify: run [last_accepted, draft_0, ..., draft_{K-1}] through
target layers (seq_len = K+1) with block-causal mask
c. Accept: longest prefix where draft[i] == target_argmax[i]
d. Emit accepted tokens + bonus correction token
e. Advance KV cache position by (accepted + 1)
```
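
Steps (c) through (e) can be sketched as a small pure-Python helper. The function name is hypothetical; it assumes `target_argmax` holds K+1 entries, where entry `i` is the target's greedy prediction for the token that follows verify input `i`.

```python
def accept_greedy(draft, target_argmax):
    """Return (emitted tokens, number of accepted draft tokens).

    draft:         K draft tokens [draft_0, ..., draft_{K-1}].
    target_argmax: K+1 greedy target tokens from the verify pass over
                   [last_accepted, draft_0, ..., draft_{K-1}].
    """
    n = 0
    # Longest prefix on which the drafter agrees with the target.
    while n < len(draft) and draft[n] == target_argmax[n]:
        n += 1
    # The target's own token at the first mismatch (or after a full match)
    # is emitted as the bonus correction token.
    bonus = target_argmax[n]
    return list(draft[:n]) + [bonus], n


emitted, accepted = accept_greedy([5, 6, 7], [5, 6, 9, 4])
kv_advance = accepted + 1  # step (e): accepted drafts plus the bonus token
```

Every verify step emits at least one token (the bonus), so the loop always makes progress even when the first draft token is rejected.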

### P-EAGLE parallel drafting (K tokens in one pass)

Unlike autoregressive EAGLE, which runs K sequential drafter passes, P-EAGLE
generates all K draft tokens simultaneously. Per step, the slots appended to the
drafter's KV cache (all attending to the full prior context) are:

| Slot | Embedding input | Hidden state input |
|------|----------------|-------------------|
| committed (incl. NTP, depth 0) | `embed(newly-accepted token)` | `fc(concat(aux_layer_1, aux_layer_11, aux_layer_20))` — real target hidden |
| K-1 MTP (depth 1..K-1) | `embed(ptd_token_id)` — placeholder | `fc(mask_hidden)` — learnable shared hidden |

The K draft logits come from the **last committed slot** (NTP) plus the K-1 ptd
slots (MTP), each through the EAGLE-3 fusion midlayer + 3 plain Llama decoder
layers. All positions attend causally (over absolute positions) to the full
context KV cache — not just to each other.
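
The slot layout for one draft step can be sketched as below. This is an illustrative helper with hypothetical names, not the code in `speculate.py`; it only builds the token ids, absolute positions, and the indices of the slots whose logits become the K draft tokens.

```python
import numpy as np


def build_draft_step(accepted, start_pos, k, ptd_token_id=201020):
    """Lay out [newly-accepted tokens | K-1 ptd slots] for one parallel pass.

    accepted:  newly accepted token ids (committed slots, real target hidden).
    start_pos: absolute position of the first new slot in the drafter cache.
    """
    n = len(accepted)
    total = n + k - 1
    token_ids = np.array(list(accepted) + [ptd_token_id] * (k - 1), dtype=np.int64)
    positions = start_pos + np.arange(total)  # consecutive absolute positions
    is_mtp = np.arange(total) >= n            # True -> hidden input is fc(mask_hidden)
    # Draft logits: last committed slot (NTP) plus the K-1 ptd slots (MTP).
    logit_slots = np.arange(n - 1, total)
    return token_ids, positions, is_mtp, logit_slots


token_ids, positions, is_mtp, logit_slots = build_draft_step([11, 12], start_pos=40, k=4)
```

With two newly accepted tokens and K=4, five slots are appended and exactly four of them produce draft logits.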

### Architecture details

The P-EAGLE drafter (`GPT-OSS-20B-P-EAGLE`, ~3.6 GB bf16):

| Component | Description |
|-----------|-------------|
| `fc` (8640→2880) | Fuses 3 target hidden states: the outputs of layers 1, 11, 20 of the 24-layer target (equivalently, the inputs to layers 2, 12, 21) |
| `midlayer` | EAGLE-3 fusion decoder layer: attention takes 2×hidden (embed⊕hidden), has `hidden_norm` |
| `layers.1/2/3` | Plain Llama decoder layers (SiLU MLP, llama3 RoPE) |
| `mask_hidden` (1,1,8640) | Learnable shared hidden state for MTP positions |
| `ptd_token_id` = 201020 | Placeholder token whose embedding fills MTP positions |
| `d2t` / `t2d` | Draft↔target vocab mapping (identity for this checkpoint) |
| `lm_head` (2880→201088) | Full target vocab, replicated on every rank |

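The `fc` fusion and the choice of tap layers can be sketched as follows. The helper names are hypothetical, toy sizes stand in for the real `hidden=2880`, and the sketch assumes `fc` is a plain `x @ W` projection with no bias.

```python
import numpy as np


def aux_tap_layers(num_layers: int) -> tuple:
    """vLLM-style (2, n//2, n-3) taps, shifted down by one to index layer outputs."""
    return (2 - 1, num_layers // 2 - 1, num_layers - 3 - 1)


def fuse_target_hidden(layer_outputs, w_fc, taps):
    """Concatenate the tapped layer outputs and project them with fc.

    layer_outputs: sequence of (seq, hidden) arrays, one per decoder layer.
    w_fc:          (3 * hidden, hidden) weight in x @ W form (bias omitted here).
    """
    cat = np.concatenate([layer_outputs[i] for i in taps], axis=-1)  # (seq, 3 * hidden)
    return cat @ w_fc                                                # (seq, hidden)


# Toy sizes for illustration; the real drafter uses hidden=2880 (3 * 2880 = 8640).
hidden, seq, num_layers = 4, 3, 24
rng = np.random.default_rng(0)
layer_outputs = [rng.standard_normal((seq, hidden)) for _ in range(num_layers)]
w_fc = rng.standard_normal((3 * hidden, hidden))
taps = aux_tap_layers(num_layers)
fused = fuse_target_hidden(layer_outputs, w_fc, taps)
```
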
### Verification

The target verifies K+1 candidate tokens in a single multi-token forward pass:
- Runs the full gpt-oss decoder stack with `seq_len = K+1` at a runtime offset
- Uses absolute-position RoPE and a block-causal attention mask
- Writes K+1 new KV cache entries contiguously
- Produces per-position greedy argmax via cross-rank reduction

**Greedy acceptance makes KV rollback implicit**: rejected speculative entries are
overwritten by the next verify pass, and the causal mask prevents any query from
attending past its own position.
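
A minimal sketch of such a mask, assuming "block-causal" means that each of the K+1 queries sees the whole prior context plus the block entries up to and including itself. The function name is hypothetical.

```python
import numpy as np


def verify_mask(offset: int, q_len: int, cache_len: int) -> np.ndarray:
    """Boolean (q_len, cache_len) mask; True means the query may attend.

    Query i sits at absolute position offset + i and may attend to every
    cache slot at or before that position.
    """
    q_pos = offset + np.arange(q_len)[:, None]
    k_pos = np.arange(cache_len)[None, :]
    return k_pos <= q_pos


# Five tokens of context already cached, verifying K+1 = 3 positions.
mask = verify_mask(offset=5, q_len=3, cache_len=10)
```

Cache slots past a query's own position are never visible, so stale entries left over from a rejected speculation cannot leak into later steps before they are overwritten.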

## Files

| File | Purpose |
|------|---------|
| `speculate.py` | Main entry: speculation loop orchestrating target + drafter |
| `config.py` | `EagleConfig` for the P-EAGLE drafter (llama3 RoPE, fc, mask_hidden, K) |
| `tensor_preparation.py` | Convert P-EAGLE checkpoint to x@W form (replicated, no TP) |
| `drafter_model.py` | Device-side drafter: loads weights, compiles kernel, runs draft |
| `kernels/drafter.py` | Parallel-drafting forward kernel (K tokens in one pass) |
| `kernels/drafter_layer.py` | EAGLE-3 fusion midlayer + plain Llama layers |
| `kernels/verify.py` | Multi-position greedy argmax for verification |
| `kernels/rope.py` | llama3 RoPE (different from target's YaRN RoPE) |
| `kernels/rmsnorm.py`, `softmax.py` | Leaf kernels (copied from base) |

## Validation

| What | Result |
|------|--------|
| Drafter layer math (`DrafterCPU`) | ✅ cos 0.9999 vs vLLM on prefill (85/86 tokens) |
| KV-cached decode steps | ✅ cos 0.9999, 100% draft-token match vs vLLM |
| `draft()` public API | ✅ 7/7 draft tokens match vLLM on multiple decode steps |
| Aux tap layers (1, 11, 20) | ✅ fc chunks equal target layer outputs at cos 1.0 |
| Rollback / context-attention | ✅ guarded by `test_drafter_cpu.py` |

## Acceptance length: root cause & fix

Earlier acceptance was ~1.4 tokens/step (vs the model card's 3.30–3.80 at K=7).
This was root-caused on GPU by running the **identical checkpoint** through vLLM's
`eagle3` parallel-drafting path and capturing its exact drafter I/O, then
reproducing it with a standalone PyTorch reference (cosine **0.9999**, 100%
draft-token match). Three issues were found and fixed:

1. **Context-blind drafting (dominant).** `speculate.py` drove `DrafterModel`
(`kernels/drafter.py`), which runs only the K draft positions under a (K,K)
cross-depth mask with **no prefill and no KV cache** — the MTP slots never saw
the prompt, so they produced generic tokens. The drafter must keep a KV cache
over the whole context and have every new position attend to it (plain causal
over absolute positions). `speculate.py` now uses the KV-cached `DrafterCPU`:
it prefills the drafter on the prompt and, each step, rolls the cache back to
the last accepted position and runs `[newly-accepted tokens | K-1 ptd slots]`.

2. **`rollback()` truncated the wrong axis.** The cache tensors are
`(B, n_kv, seq, head_dim)`; `rollback()` sliced dim 1 (`n_kv`) instead of dim 2
(`seq`), so rejected speculative KV entries were never discarded and corrupted
every later step. Fixed to slice the sequence axis. Guarded by
`test_drafter_cpu.py::test_rollback_restores_clean_cache`.

3. **Aux tap off-by-one.** vLLM's EAGLE-3 default `(2, n//2, n-3)` captures the
residual stream *entering* those layers (`layer_idx=i+1` after layer `i`), i.e.
the **outputs of layers (1, 11, 20)** for the 24-layer target. Our prefill loop
captures *after* layer `i`, so `default_aux_layers` now returns `(1, 11, 20)`.
Verified on GPU: the drafter's 3 fc chunks equal target layer outputs (1,11,20)
at cosine 1.0.

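The axis mix-up in issue 2 is easy to reproduce in isolation. The sketch below uses NumPy and illustrative sizes (8 KV heads, 16 cached positions, head_dim 64) rather than the drafter's real tensors; `rollback` here is a stand-in, not the method in the repository.

```python
import numpy as np


def rollback(cache: np.ndarray, keep_len: int) -> np.ndarray:
    """Keep the first keep_len positions of a (B, n_kv, seq, head_dim) cache."""
    return cache[:, :, :keep_len, :]


cache = np.zeros((1, 8, 16, 64), dtype=np.float32)

buggy = cache[:, :5]        # slices dim 1 (n_kv): drops heads, keeps all 16 positions
fixed = rollback(cache, 5)  # slices dim 2 (seq): keeps all heads, 5 positions
```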
**Prompt formatting matters too.** The drafter is trained on chat data; raw
completion prompts are out-of-distribution and roughly halve acceptance (GPU,
identical checkpoint, K=7: **3.65** chat vs **1.99** raw). `speculate.py` now
applies the chat template by default (`--raw-prompt` to opt out).

### Validated algorithm (matches vLLM)

1. Target taps the residual stream (`x+residual`) at the **outputs of layers
(1, 11, 20)**; concat the 3 → `fc` → H.
2. **EAGLE shift:** drafter slot `p` pairs `embed(token@p+1)` with
`target_hidden@p`.
3. **Drafter prefill** over the prompt builds a KV cache (positions 0..P-2), plain
causal.
4. **Each draft step:** roll the cache back to the last accepted position; append
`[newly-accepted tokens (real target hidden) | K-1 ptd slots (fc(mask_hidden),
ptd_token_id embedding)]` at consecutive absolute positions, attending to the
full cache; the K draft logits are the last committed slot (NTP) + the K-1 ptd
slots (MTP).

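The EAGLE shift in steps 2 and 3 amounts to pairing two offset views of the prompt. A minimal sketch with a hypothetical helper name, assuming `target_hidden` already holds the fused hidden state for each prompt position:

```python
import numpy as np


def eagle_shift(prompt_ids: np.ndarray, target_hidden: np.ndarray):
    """Pair embed(token @ p+1) with target_hidden @ p for drafter prefill.

    prompt_ids:    (P,) prompt token ids.
    target_hidden: (P, H) fused target hidden state per prompt position.
    Returns P-1 slots at absolute positions 0..P-2.
    """
    tokens = prompt_ids[1:]      # token at p+1 feeds slot p
    hidden = target_hidden[:-1]  # hidden at p feeds slot p
    positions = np.arange(len(prompt_ids) - 1)
    return tokens, hidden, positions


prompt_ids = np.array([10, 11, 12, 13, 14])
target_hidden = np.arange(5 * 2, dtype=np.float32).reshape(5, 2)
tokens, hidden, positions = eagle_shift(prompt_ids, target_hidden)
```
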
### vLLM reference (parallel_drafting)

vLLM produces all K tokens in **one forward pass**: the expanded input is
`[shifted context tokens | bonus (next_token) | K-1 ptd_token positions]`, all run
together. `parallel_drafting_hidden_state_tensor = fc(mask_hidden)` fills the MTP
positions; the `copy_and_expand_eagle_inputs_kernel` lays out sequential positions
(`start_pos + j`) and tags parallel-draft slots with `ptd_token_id`. Only the
midlayer concatenates embeds with hidden to 2H; later layers are standard.

### Status / remaining work

The CPU drafter path (`DrafterCPU`) and the `speculate.py` loop bookkeeping are
GPU-validated against vLLM. Not yet re-validated on Trainium hardware: the NKI
target's `verify`/prefill numerics and end-to-end acceptance. The on-device
`kernels/drafter.py` still lacks a KV cache and is currently unused by
`speculate.py`; porting the validated KV-cached path to the device kernel is the
remaining performance work.