@tscholak commented Nov 27, 2025

This PR enables declarative architecture surgery on Apriel2 models. You can now convert a standard attention-based model into exotic architectures (stochastic supernets, hybrid Mamba/GDN models) with a single command and a YAML config.

To convert Apriel-1.5-15b-Thinker into a stochastic supernet with 3 mixer types per layer:

python convert.py ServiceNow-AI/Apriel-1.5-15b-Thinker /tmp/apriel2-supernet \
    --surgery examples/stochastic_supernet.yaml --verbose

The converter prints a visual plan tree showing exactly what it will do:

============================================================
CONVERSION PLAN
============================================================
├── model/
│   ├── [embed_tokens, norm]/
│   │   └── weight ← language_model.model.[embed_tokens, norm].weight
│   ├── vision_encoder/
│   │   └── encoder/blocks/[0..23]/...
│   └── decoder/
│       └── blocks/
│           └── [0..47]/
│               ├── mixer/
│               │   └── mixers/
│               │       ├── attention/
│               │       │   └── self_attn/[q,k,v,o]_proj/weight ← ...self_attn.[q,k,v,o]_proj.weight
│               │       ├── sliding_window/
│               │       │   └── self_attn/[q,k,v,o]_proj/weight ← ...self_attn.[q,k,v,o]_proj.weight
│               │       └── gated_delta_net/
│               │           └── gdn/
│               │               ├── in_proj_qkvz/weight = [Q[:160,:]; K[:160,:]; V slices...; 𝟎(640×5120)]
│               │               ├── in_proj_ba/weight = 𝟎(64×5120)
│               │               ├── out_proj/weight ← ...self_attn.o_proj.weight
│               │               ├── conv1d/weight = scaled_identity_conv(7680×1×4)
│               │               ├── A_log = A_log(32)
│               │               └── norm/weight = 𝟏(160)
│               ├── mlp/[gate,up,down]_proj/weight ← ...
│               └── [input_layernorm, post_attention_layernorm]/weight ← ...
└── lm_head/weight ← language_model.lm_head.weight
============================================================
Converting: 100%|████████████████████████████████| 1185/1185 [00:40<00:00]
Conversion complete! Output saved to /tmp/apriel2-supernet

The result is a sharded Apriel2 checkpoint with attention + sliding window + gated delta net mixers in every layer.

The surgery config for this conversion looks like this:

decoder:
  type: fixed
  block:
    mixer:
      type: stochastic
      main_mixer_name: attention
      sampling_strategy: uniform
      mixers:
        # Main attention mixer - inherits config and weights from source
        attention:
          type: attention
          init: transfer

        # Sliding window - same architecture with window size override
        sliding_window:
          type: attention
          init: transfer
          sliding_window: 4096

        # Gated delta net - DIL initialization maps Q/K/V/O -> GDN projections
        gated_delta_net:
          type: gated_delta_net
          init: transfer
          conv_kernel_size: 4  # Only required param - rest derived from source

    # MLP and normalization transfer from source
    mlp:
      init: transfer

    normalization:
      init: transfer

The converter supports the following better-than-random initializations for linear mixers:

  • MIL (Mamba Initialization from LLM): Q→C, K→B, V→x, O→out_proj
  • DIL: Q/K/V→in_proj_qkvz with GQA-aware tiling
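To make the DIL mapping concrete, here is a minimal sketch of the packing step: rows sliced from the source q/k/v projections are stacked into the head of the GDN's in_proj_qkvz weight, and the gate rows (which have no attention counterpart) are zero-initialized, matching the zero block shown in the plan tree. Dimensions and the helper name are toy values for illustration, not the real Apriel2 converter code.

```python
# Hypothetical sketch of DIL-style packing for in_proj_qkvz.
# Tensors are plain row-lists; the real converter operates on safetensors.

def dil_pack_qkvz(q_rows, k_rows, v_rows, z_rows, hidden):
    """Stack sliced q/k/v rows and append zero rows for the z gate."""
    zeros = [[0.0] * hidden for _ in range(z_rows)]
    return q_rows + k_rows + v_rows + zeros

hidden = 8
q = [[1.0] * hidden for _ in range(4)]  # e.g. first head of q_proj
k = [[2.0] * hidden for _ in range(4)]
v = [[3.0] * hidden for _ in range(4)]
w = dil_pack_qkvz(q, k, v, z_rows=4, hidden=hidden)

assert len(w) == 16  # 4 q + 4 k + 4 v + 4 zero (z) rows
assert w[0][0] == 1.0 and w[4][0] == 2.0 and w[8][0] == 3.0
assert all(x == 0.0 for x in w[12])
```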

tscholak and others added 4 commits November 27, 2025 19:24
- Rename Apriel2CheckpointFormat to Apriel2TextCheckpointFormat for text-only models
- Add new Apriel2CheckpointFormat for multimodal models (tabled for now)
- Replace num_hidden_layers with num_blocks in decoder config (Fast-LLM convention)
- Update test fixtures to use num_blocks in decoder configs
- Fix stochastic mixer preprocess() to collect attention_mask from nested mixers
- Add cache initialization to Apriel2GatedDeltaNet for lazy allocation
- Use past_key_values (plural) consistently per HuggingFace convention
- Update test code to use model.model.decoder.blocks[idx] accessor

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
…aches

- Test 1: Empty cache vs filled cache - verifies cache is being used at all
- Test 2: Corrupted cache (zeros) vs correct cache - verifies cache VALUES matter
- Derive cache dimensions from actual forward pass (handles different attention configs)
- Fix: original test used wrong attribute names (key_cache/value_cache instead of key/value)

- Update modeling_apriel2.py to use direct dict access instead of helper
  methods (config.embeddings["max_position_embeddings"] instead of
  config.get_max_position_embeddings())
- Fix activation export in vision adapter converter to use .hf_name
  instead of .value for proper round-trip conversion
- Fix MultiModalInferenceRunner naming in multimodal/config.py
- Raise NotImplementedError for multimodal HF wrapper (not implemented)

- Multimodal converter: stratified inheritance from Pixtral/LLaVA
  - Inherit get_converters for Attention, Block, Encoder, Adapter (shares weight conversion logic)
  - Standalone PatchConvolutionConverter (different paths, no meaningful sharing)
  - Override all import_config/export_config (different naming and nested structure)
- Remove verbose docstrings and self-narrative comments from all Apriel2 files

@tscholak force-pushed the tscholak/apriel2-conversion branch from 6d67133 to 98a5d25 on November 27, 2025 19:25
tscholak and others added 7 commits November 28, 2025 11:23
Introduces convert_from_llava.py which converts Llava/Pixtral models
(like Apriel 1.5) to Apriel2 format. The converter handles:
- Config conversion from Llava to Apriel2 format
- Weight mapping between different naming conventions
- Vision encoder, projector, and language model weights
- Support for both local paths and HuggingFace model IDs

Test coverage includes:
- Config conversion validation
- Component-level forward pass equivalence (embeddings, vision encoder,
  projector, language model layers)
- Full model forward pass equivalence for text-only inputs
- Multimodal forward pass validation (image + text inputs)
- Apriel 1.5 large model conversion test (marked as slow)

Note: Multimodal numerical equivalence is not possible due to
architectural differences between Pixtral and Apriel2 vision encoders
(Pixtral produces (size/16)^2 - 1 patches vs Apriel2's (size/16)^2).
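The sequence-length mismatch follows directly from the two patch-count formulas quoted above; for a 1024-pixel square image the encoders disagree by exactly one token, so element-wise output comparison is impossible:

```python
# Patch counts per the formulas in the note (16x16 patches assumed).

def apriel2_patches(size: int, patch: int = 16) -> int:
    return (size // patch) ** 2

def pixtral_patches(size: int, patch: int = 16) -> int:
    return (size // patch) ** 2 - 1

assert apriel2_patches(1024) == 4096
assert pixtral_patches(1024) == 4095
assert apriel2_patches(1024) != pixtral_patches(1024)
```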

Refactors the Llava-to-Apriel2 converter to cleanly separate concerns:

1. **convert_from_llava.py** - Pure format conversion (Llava -> Apriel2)
   - Config conversion: 1-to-1 mapping of Llava config to Apriel2 format
   - Weight conversion: Pure name mapping, no transformations
   - No surgery logic - just format translation

2. **surgery.py** - Generic Apriel2 -> Apriel2 transformation
   - Layer-by-layer conversion using converter registry
   - For stochastic mixers, source is always the main mixer
   - Supports wrapping attention with stochastic mixer
   - Random initialization for incompatible conversions (e.g., attention -> mamba)

3. **converters.py** - Converter registry and implementations
   - Identity: forall a. a -> a
   - Bidirectional: attention <-> sliding_window
   - Random init utilities for mamba, attention, gated_delta_net

Benefits:
- Surgery can be applied to ANY Apriel2 model, not just converted ones
- Easy to add new source formats (Qwen, Llama, etc.)
- No intermediate persistence - all operations on in-memory state dicts
- Cleaner code: 725 lines removed in refactor

- Add expr_plan.py: declarative weight transformation with composable
  expressions (Ref, Slice, Concat, Init, Reshape) and streaming executor
- Implement MIL (Mamba Initialization from LLM) for attention->mamba surgery
- Remove legacy converters.py and surgery.py (imperative approach)
- Simplify convert_from_llava.py to use plan-based streaming only
- Update tests to use new expr_plan API

The plan system enables:
- Composable conversions via plan composition (Llava->Apriel2->Modified)
- Memory-efficient streaming execution with ref-counting
- Declarative, inspectable transformation plans
- W path builder for readable key construction
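The expression DSL described above can be sketched in a few dataclasses: Ref reads a source tensor, Slice takes a row prefix, Concat stacks results, and Init materializes constants. This is a toy evaluator over row-lists to show the compositional idea; the real executor streams safetensors with ref-counting.

```python
# Minimal sketch of a composable expression plan in the spirit of expr.py.
from dataclasses import dataclass

class Expr: ...

@dataclass
class Ref(Expr):
    key: str           # source tensor name

@dataclass
class Slice(Expr):
    inner: Expr
    stop: int          # keep rows [0:stop]

@dataclass
class Concat(Expr):
    parts: list

@dataclass
class Init(Expr):
    rows: int
    cols: int
    fill: float = 0.0  # e.g. zero-init for weights with no source

def execute(expr, source):
    if isinstance(expr, Ref):
        return source[expr.key]
    if isinstance(expr, Slice):
        return execute(expr.inner, source)[: expr.stop]
    if isinstance(expr, Concat):
        return [row for p in expr.parts for row in execute(p, source)]
    if isinstance(expr, Init):
        return [[expr.fill] * expr.cols for _ in range(expr.rows)]
    raise TypeError(type(expr))

src = {"q_proj.weight": [[1.0, 2.0], [3.0, 4.0]]}
plan = Concat([Slice(Ref("q_proj.weight"), 1), Init(1, 2)])
assert execute(plan, src) == [[1.0, 2.0], [0.0, 0.0]]
```

Because plans are plain data, composing Llava→Apriel2 with Apriel2→Modified is just substituting one plan's outputs for another's inputs.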

Key changes:
- Add GatedDeltaNet (DIL) conversion from attention weights
- Support stochastic mixer with multiple sub-mixers (attention + mamba/GDN)
- Add dt_init_floor parameter for Mamba dt_bias initialization
- Fix plan tree collapsing to merge layers but not projections
- Add example YAML configs for hybrid architectures

The tree collapsing fix ensures that layers [0..47] are merged at the
blocks level while projections (q_proj, k_proj, etc.) remain separate.
This is achieved by tracking which positions vary within each group
and only allowing merges when the cross-group variation matches.
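A toy version of the collapsing rule: keys that differ only in one numeric segment merge into a `[lo..hi]` range, while keys that differ in a non-numeric segment (q_proj vs k_proj) stay separate. The real implementation additionally checks that the variation on the source side matches before merging; this sketch omits that.

```python
# Toy sketch of plan-tree collapsing over flat weight keys.
import re

def collapse(keys):
    groups = {}
    for key in keys:
        # Replace digit runs with a placeholder to find the shared pattern.
        pattern = re.sub(r"\d+", "{}", key)
        nums = tuple(int(n) for n in re.findall(r"\d+", key))
        groups.setdefault(pattern, []).append(nums)
    out = []
    for pattern, occurrences in groups.items():
        if len(occurrences) == 1:
            out.append(pattern.format(*occurrences[0]))
        else:
            lo = min(occurrences)[0]
            hi = max(occurrences)[0]
            out.append(pattern.format(f"[{lo}..{hi}]"))
    return sorted(out)

keys = [f"blocks.{i}.mixer.{p}_proj.weight" for i in range(48) for p in "qk"]
assert collapse(keys) == [
    "blocks.[0..47].mixer.k_proj.weight",
    "blocks.[0..47].mixer.q_proj.weight",
]
```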

- Add SafetensorLoader context manager for O(1) key lookup across sharded files
- Add ShardedSafetensorWriter for streaming output with configurable shard size
- Update convert_from_llava.py to use streaming pipeline
- Bound peak memory to ~5GB instead of ~30GB for large models
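The rollover logic behind a sharded streaming writer is simple to sketch: tensors accumulate into the current shard until adding the next one would exceed the configured shard size, at which point the shard is flushed and a new one starts. Byte counts stand in for real safetensors serialization here; the function name is a hypothetical stand-in.

```python
# Sketch of bounded-size shard planning for a streaming writer.

def plan_shards(tensor_sizes, max_shard_bytes):
    """Group (name, nbytes) pairs into shards of bounded total size."""
    shards, current, used = [], [], 0
    for name, nbytes in tensor_sizes:
        if current and used + nbytes > max_shard_bytes:
            shards.append(current)   # flush the full shard
            current, used = [], 0
        current.append(name)
        used += nbytes
    if current:
        shards.append(current)
    return shards

sizes = [("a", 3), ("b", 3), ("c", 5), ("d", 2)]
assert plan_shards(sizes, max_shard_bytes=6) == [["a", "b"], ["c"], ["d"]]
```

Peak memory is then bounded by one shard plus whatever the executor holds for in-flight expressions, rather than the whole model.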

…erter

- Split monolithic expr_plan.py into conversion/ subpackage:
  - expr.py: Expression DSL types (Ref, Slice, Concat, Init, Reshape)
  - render.py: Plan rendering and tree visualization
  - executor.py: Plan execution and streaming executor
  - io.py: SafetensorLoader and ShardedSafetensorWriter
  - converters.py: MIL/DIL converters and surgery planning

- Move Llava-specific code into conversion/llava/:
  - config.py: Llava config to Apriel2 config conversion
  - plan.py: Llava to Apriel2 weight plan builder

- Create source-format agnostic convert.py:
  - Registry pattern for source formats (SOURCE_FORMATS dict)
  - Auto-detection via detect_source_format()
  - Generic build_plan() and convert() functions
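The registry pattern can be sketched as follows. The names `SOURCE_FORMATS` and `detect_source_format` mirror the commit message; the builder function and the detection-by-architecture heuristic are hypothetical stand-ins for the real implementation.

```python
# Sketch of the source-format registry and auto-detection.

def build_llava_plan(config):
    # Stand-in for the real Llava -> Apriel2 plan builder.
    return f"llava plan for {config['architectures'][0]}"

SOURCE_FORMATS = {
    "LlavaForConditionalGeneration": build_llava_plan,
}

def detect_source_format(config):
    for arch in config.get("architectures", []):
        if arch in SOURCE_FORMATS:
            return arch
    raise ValueError("unsupported source format")

cfg = {"architectures": ["LlavaForConditionalGeneration"]}
fmt = detect_source_format(cfg)
assert fmt == "LlavaForConditionalGeneration"
assert SOURCE_FORMATS[fmt](cfg).startswith("llava plan")
```

Adding a new source format (Qwen, Llama, ...) is then a single registry entry plus a plan builder.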

- Update tests to use new imports and add seed=0 to execute() calls

The GDN uses DIL initialization which maps attention Q/K/V/O weights
to GDN projections. Only conv_kernel_size needs to be specified -
other dimensions (num_value_heads, num_key_heads, head dims) are
automatically derived from the source attention config.
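A hypothetical sketch of that derivation, consistent with the plan tree earlier in this PR (head_dim 160 and 32 heads for a 5120-wide model): only conv_kernel_size is user-supplied, and the remaining GDN dimensions fall out of the source attention geometry. The exact field names and mapping in the real converter may differ.

```python
# Hypothetical derivation of GDN dims from a source attention config.

def derive_gdn_dims(attn, conv_kernel_size):
    return {
        "num_value_heads": attn["num_attention_heads"],
        "num_key_heads": attn["num_key_value_heads"],
        "head_dim": attn["hidden_size"] // attn["num_attention_heads"],
        "conv_kernel_size": conv_kernel_size,  # the only user-supplied value
    }

attn = {"hidden_size": 5120, "num_attention_heads": 32, "num_key_value_heads": 8}
dims = derive_gdn_dims(attn, conv_kernel_size=4)
assert dims["head_dim"] == 160
assert dims["num_key_heads"] == 8
assert dims["conv_kernel_size"] == 4
```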
