Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pip install safetensors torch
python scripts/convert_nemo.py parakeet-tdt_ctc-110m.nemo -o model.safetensors
```

The converter supports all model types: `110m-tdt-ctc` (default), `600m-tdt`, `eou-120m`, `nemotron-600m`, `sortformer`.
The converter auto-detects the model type by default. It also supports explicit model types: `110m-tdt-ctc`, `600m-tdt`, `eou-120m`, `nemotron-600m`, `sortformer`.

```bash
python scripts/convert_nemo.py checkpoint.nemo -o model.safetensors --model 600m-tdt
Expand Down
53 changes: 48 additions & 5 deletions scripts/convert_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""

import argparse
import re
import tarfile
import tempfile
import sys
Expand Down Expand Up @@ -93,6 +94,44 @@
NUM_DURATIONS = 5


def infer_model_type(state_dict):
"""Infer the model preset from checkpoint tensor shapes.

This keeps the common path safe: converting a 600M checkpoint with the
110M default silently produces loadable but invalid joint weights.
"""
if any(k.startswith("sortformer_modules.") for k in state_dict):
return "sortformer"

embed = state_dict.get("decoder.prediction.embed.weight")
if embed is None:
raise ValueError(
"could not infer model type: missing decoder.prediction.embed.weight; "
"pass --model explicitly"
)

vocab_size = int(embed.shape[0])
layer_indices = set()
for key in state_dict:
match = re.match(r"encoder\.layers\.(\d+)\.", key)
if match:
layer_indices.add(int(match.group(1)))
num_layers = len(layer_indices)

if vocab_size == 8193:
return "600m-tdt"
if vocab_size == 1027:
return "eou-120m"
if vocab_size == 1025 and num_layers == 17:
return "110m-tdt-ctc"

raise ValueError(
"could not infer model type from checkpoint "
f"(vocab_size={vocab_size}, encoder_layers={num_layers}); "
"pass --model explicitly"
)


# ─── NeMo → Axiom name mapping ──────────────────────────────────────────────

def build_subsampling_map(axiom_prefix="encoder_"):
Expand Down Expand Up @@ -384,8 +423,13 @@ def dump_keys(ckpt_path):
print(f" {key:70s} {list(t.shape)}")


def convert(ckpt_path, output_path, model_type=DEFAULT_MODEL):
def convert(ckpt_path, output_path, model_type="auto"):
"""Convert NeMo checkpoint to axiom safetensors."""
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
if model_type == "auto":
model_type = infer_model_type(state_dict)
print(f"Auto-detected model type: {model_type}")

preset = MODEL_PRESETS[model_type]
vocab_size = preset["vocab_size"]
num_durations = preset["num_durations"]
Expand All @@ -397,7 +441,6 @@ def convert(ckpt_path, output_path, model_type=DEFAULT_MODEL):
print(f" Encoder layers: {preset['num_layers']}, vocab: {vocab_size}, "
f"LSTM layers: {num_lstm_layers}, CTC: {preset.get('has_ctc', False)}")

state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
mapping = build_full_mapping(preset)

output = {}
Expand Down Expand Up @@ -517,9 +560,9 @@ def main():
help="Output safetensors file (default: model.safetensors)")
parser.add_argument("--dump", action="store_true",
help="Just dump checkpoint keys and shapes")
parser.add_argument("--model", choices=list(MODEL_PRESETS.keys()),
default=DEFAULT_MODEL,
help=f"Model type (default: {DEFAULT_MODEL})")
parser.add_argument("--model", choices=["auto"] + list(MODEL_PRESETS.keys()),
default="auto",
help="Model type (default: auto-detect from checkpoint shapes)")
args = parser.parse_args()

ckpt_path = extract_checkpoint(args.input)
Expand Down