Skip to content

[WIP] Feat/gpt oss example#63

Draft
ymwangg wants to merge 7 commits into
mainfrom
feat/gpt-oss-example
Draft

[WIP] Feat/gpt oss example#63
ymwangg wants to merge 7 commits into
mainfrom
feat/gpt-oss-example

Conversation

@ymwangg

@ymwangg ymwangg commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

Add p-eagle gpt-oss-20b example.

ymwangg and others added 7 commits June 25, 2026 13:38
Add a NKIPy example for OpenAI's gpt-oss MoE models (gpt-oss-20b / 120b),
mirroring the qwen3 example structure. The implementation is fully
config-driven, so both sizes share one codebase.

gpt-oss-specific handling:
- MXFP4 experts dequantized to bf16 at prep time
- interleaved gate/up de-interleaved at prep time
- clamped SwiGLU with gate_up/down biases
- per-head attention sinks + QKV/O biases (no QK-norm)
- alternating sliding-window / full attention (one kernel per type)
- YaRN RoPE (inv_freq precomputed from HF config)
- router with top-k-then-softmax and router bias

Validated against HF on trn2 (TP=4): every generated token matches HF's
argmax or a bf16-resolution tie.
Implements parallel-drafting P-EAGLE (arXiv 2602.01469) on top of the
gpt-oss base model for speculative decoding on Trainium.

Components added (examples/models/gpt_oss/eagle/):
- config.py: EagleConfig for the 4-layer P-EAGLE drafter (llama3 RoPE,
  fc fusion, mask_hidden/ptd_token_id, d2t vocab map)
- tensor_preparation.py: convert P-EAGLE checkpoint to x@W form (replicated)
- kernels/drafter.py: parallel-drafting forward - K tokens in one pass via
  NTP (real hidden) + MTP (mask_hidden) positions with cross-depth mask
- kernels/drafter_layer.py: EAGLE-3 fusion midlayer + plain Llama layers
- kernels/verify.py: multi-position greedy argmax for verification
- drafter_model.py: device-side drafter model + compile
- speculate.py: full speculation loop (prefill → draft → verify → accept)

Base model changes:
- config.py: added aux_layers config + default_aux_layers() for EAGLE-3 taps
- gpt_oss.py: run_prefill() now optionally captures pre-layer hidden states
  at the 3 EAGLE-3 tap layers (2, L/2, L-3)
- kernels/attention.py: generalized decode path to support seq_len>1 (for
  the multi-token verify pass) via query_pos = start_pos + arange(seq_len)

Status: functionally correct (lossless greedy output verified against HF).
Acceptance length is below the paper's reported ~3.3 — under investigation
(likely a hidden-state position/timing issue in the draft-verify loop seeding).
…yers

Switch aux capture to post-layer (output of tap layers 2/12/21) based on
HF validation showing the drafter predicts correctly with HF's hidden
states at hs[3]/hs[13]/hs[22] (output of layers 2/12/21).

Note: acceptance length remains low (~1.0) due to numerical divergence
between nkipy's Neuron-compiled target and the HF CPU reference the
drafter was trained against. The drafter kernel is mathematically correct
(validated against independent torch reference) and correctly predicts
the target when fed exact HF hidden states. The gap is an
implementation-coupling issue inherent to EAGLE-style speculation.
Key findings from the P-EAGLE paper (Figure 2, Figure 3, Section 3):

1. The drafter maintains its own KV cache across the full context
   (prompt + all accepted tokens). At each draft step, K positions
   attend to the FULL accumulated cache.

2. The attention mask is GROUP-CAUSAL: all K positions see the full
   cache (group 0), but within the K positions the NTP (group 1)
   and MTP (group 2+) positions use cross-depth causality — MTP
   positions cannot attend to positions at the same or later depth.

3. The NTP pair is (emb(t_n), hidden_after_processing_t_{n-1}),
   predicting t_{n+1}. The hidden is one step behind the embedding.

This commit adds:
- drafter_cpu.py: CPU reference drafter with full KV cache and
  standard causal attention (working infrastructure, mask needs
  the group-causal refinement for MTP positions)
- Fixes hidden state capture to post-layer (output of tap layers)
- Adds peagle_aux_layers config method

Status: KV cache infrastructure correct, still needs the group-causal
mask refinement for the MTP positions within the K-wide draft window.
Root-caused the low acceptance length (~1.4 vs the card's 3.30-3.80 at K=7)
on GPU by running the identical checkpoint through vLLM's eagle3
parallel-drafting path, capturing its drafter I/O, and reproducing it with a
standalone PyTorch reference (cosine 0.9999, 100% draft-token match). Three
bugs plus a prompt-formatting issue:

1. Context-blind drafting (dominant): speculate.py drove DrafterModel
   (kernels/drafter.py), which runs only the K draft positions under a (K,K)
   cross-depth mask with no prefill and no KV cache, so the MTP slots never
   saw the prompt. Rewired speculate.py to use the KV-cached DrafterCPU:
   prefill the drafter on the prompt (EAGLE +1 shift), then each step roll the
   cache back to the last accepted position and run [newly-accepted tokens |
   K-1 ptd slots] in one parallel forward attending to the full context.

2. rollback() truncated the wrong axis: the cache is (B, n_kv, seq, head_dim)
   and rollback sliced dim 1 (n_kv) instead of dim 2 (seq), so rejected
   speculative KV was never discarded and corrupted later steps.

3. Aux tap off-by-one: vLLM's eagle3 default (2, n//2, n-3) captures the
   residual stream entering those layers; our post-layer capture must shift
   down one, so default_aux_layers now returns (1, 11, 20). Verified on GPU
   the drafter's 3 fc chunks equal target layer outputs (1, 11, 20) at cos 1.0.

4. Prompt formatting: the drafter is trained on chat data; raw prompts roughly
   halve acceptance (GPU, K=7: 3.65 chat vs 1.99 raw). speculate.py now applies
   the chat template by default (--raw-prompt to opt out).

Still produces all K draft tokens in a single forward pass (parallel drafting).
Adds test_drafter_cpu.py guarding the rollback/full-context invariants (skips
without the checkpoint). Validated against vLLM on GPU; not yet re-validated on
Trainium, and the on-device kernels/drafter.py KV-cache port remains follow-up.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant