SPDX-FileCopyrightText: 2024 Idiap Research Institute <[email protected]>
SPDX-FileContributor: Kaloga Yacouba <[email protected]>
SPDX-FileContributor: Shashi Kumar <[email protected]>
Implementation of the Optimal Temporal Transport Classification (OTTC) loss introduced in A Differentiable Alignment Framework for Sequence-to-Sequence Modeling via Optimal Transport (Kaloga et al., 2025). OTTC replaces the non-differentiable alignment surrogate used by CTC with a 1-D optimal transport solver that stays fully differentiable, produces sharper alignments, and maintains practical training stability on modern GPUs.
Paper: https://arxiv.org/abs/2502.01588
Questions are welcome at:

- Yacouba Kaloga ([email protected])
- Shashi Kumar ([email protected])
- Overview
- Highlights
- Getting Started
- Configuring Paths
- Datasets and Pretrained Encoders
- Running Experiments
- Monitoring and Outputs
- Reproducing the Paper
- Repository Layout
- Citation
- License
## Overview

OTTC formulates alignment between acoustic features and target symbols as a differentiable optimal transport problem. The project provides:
- Reference PyTorch implementations of OTTC, soft-DTW, and CTC heads on top of Wav2Vec2-style encoders.
- Training and evaluation pipelines for speech recognition benchmarks (LibriSpeech, Multilingual LibriSpeech, AMI, and TIMIT) using the Hugging Face ecosystem.
- Utilities to build vocabularies, collate speech batches, and compute WER for evaluation.
## Highlights

- Differentiable alignments: Fast batched 1-D optimal transport with cross-entropy, Euclidean, soft-Euclidean, or Jensen–Shannon costs (see `ottc/ops_loss/ot_loss.py`); a toy sketch follows this list.
- Baseline losses: Drop-in CTC and soft-DTW heads to benchmark against OTTC.
- Encoder flexibility: Works with Hugging Face checkpoints such as Wav2Vec2-Large-LV60, XLS-R, Whisper-Small, etc. (`--encoder_name`).
- Dataset tooling: Ready-to-use loaders, collators, and vocab builders for LibriSpeech, Multilingual LibriSpeech, AMI, and TIMIT (letters, phonemes, fused phonemes).
- Resumable training: Automatic checkpoint discovery when `--continue_if_exist` is enabled, plus evaluation-only runs for frozen models.
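For intuition about the differentiable-alignment idea, here is a minimal, self-contained sketch of a 1-D transport-style alignment loss with a cross-entropy cost: because frame and token positions are both ordered, the monotone (north-west-corner) coupling between uniform marginals is computed in closed form, and gradients flow through the per-pair costs. This toy is not the batched solver shipped in `ottc/ops_loss/ot_loss.py`; the function names, the uniform marginals, and the fixed cross-entropy cost are our simplifying assumptions.

```python
import torch
import torch.nn.functional as F

def monotone_coupling(T: int, L: int) -> torch.Tensor:
    """North-west-corner coupling between uniform marginals over T frames
    and L tokens; for sorted 1-D supports this monotone plan is OT-optimal."""
    plan = torch.zeros(T, L)
    row, col = 1.0 / T, 1.0 / L   # remaining mass of current frame / token
    t = l = 0
    while t < T and l < L:
        mass = min(row, col)
        plan[t, l] += mass
        row -= mass
        col -= mass
        if row <= 1e-12:
            t, row = t + 1, 1.0 / T
        if col <= 1e-12:
            l, col = l + 1, 1.0 / L
    return plan

def ot_alignment_loss(log_probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """log_probs: (T, V) frame-wise log-posteriors; targets: (L,) token ids.
    The cost of pairing frame t with token l is the cross-entropy -log p_t(y_l)."""
    T, _ = log_probs.shape
    L = targets.numel()
    cost = -log_probs[:, targets]              # (T, L) pairwise CE costs
    plan = monotone_coupling(T, L).to(cost)    # fixed transport plan
    return (plan * cost).sum()

# Toy usage: 50 frames, 29-symbol vocabulary, 12-token target.
logits = torch.randn(50, 29, requires_grad=True)
loss = ot_alignment_loss(F.log_softmax(logits, dim=-1), torch.randint(0, 29, (12,)))
loss.backward()
```

In the actual loss the cost family (`--bins_metric`) and the coupling scheme are configurable; the sketch fixes both for brevity.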
## Getting Started

```bash
git clone https://github.com/idiap/OTTC.git
cd OTTC
```

The recommended setup is via Conda (CUDA 12.1, PyTorch 2.3):

```bash
conda env create -f env.yaml
conda activate ottcenv
```

Alternatively, create a virtualenv and install dependencies from `requirements.txt` (ensure that a CUDA-enabled PyTorch build >= 2.3.0 is already installed):

```bash
python -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt
```

Running the scripts from the repository root automatically exposes the local `ottc` package. If you need to launch Python from elsewhere, export `PYTHONPATH=$(pwd)` (or the absolute path to this repository) before invoking the scripts.
## Configuring Paths

All data locations live in `ottc/config/path_config.py`. A ready-to-fill template with placeholder paths is available at `ottc/config/path_config_to_complete.py`; copy it (or edit the existing `path_config.py`) and update the values before running experiments:

```python
from pathlib import Path

DATASETS_ROOT_PATH = Path('/your/datasets/root')
HUGGINGFACE_CACHE = Path('/your/hf_cache')
LARGE_MODELS_PATH = Path('/your/pretrained/models')
TIMIT_DATASET_PATH = Path('/your/timit/root')  # optional override when TIMIT lives elsewhere
```

These directories host:

- extracted datasets (LibriSpeech, AMI, etc.)
- the Hugging Face cache for dataset and model downloads
- pretrained encoders and experiment outputs (`large_models_results/`)

Common Hugging Face checkpoints (e.g., `facebook/wav2vec2-large-lv60`, `facebook/wav2vec2-large-xlsr-53`, `openai/whisper-small`) are downloaded automatically the first time they are requested via `--encoder_name` and cached under `${LARGE_MODELS_PATH}` for reuse.

Make sure the parent folders exist and are writable. If you work on a shared cluster, you can point these variables to scratch volumes and export environment variables such as `export LARGE_MODELS_PATH=/scratch/$USER/large_models` before launching jobs.
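One way to honor such environment-variable overrides inside `path_config.py` is sketched below. This is an assumed pattern, not necessarily how the shipped file is written; the helper `_path_from_env` and the default paths are illustrative.

```python
# Hypothetical layout for ottc/config/path_config.py: each location falls back
# to a hard-coded default unless the matching environment variable is set.
import os
from pathlib import Path

def _path_from_env(var: str, default: str) -> Path:
    """Return Path(os.environ[var]) if the variable is set, else the default."""
    return Path(os.environ.get(var, default))

DATASETS_ROOT_PATH = _path_from_env('DATASETS_ROOT_PATH', '/your/datasets/root')
HUGGINGFACE_CACHE = _path_from_env('HUGGINGFACE_CACHE', '/your/hf_cache')
LARGE_MODELS_PATH = _path_from_env('LARGE_MODELS_PATH', '/your/pretrained/models')
TIMIT_DATASET_PATH = _path_from_env('TIMIT_DATASET_PATH', str(DATASETS_ROOT_PATH / 'timit'))
```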
## Datasets and Pretrained Encoders

- LibriSpeech / Multilingual LibriSpeech: downloaded via 🤗 Datasets. The loader expects `librispeech_asr` and `facebook/multilingual_librispeech` under `DATASETS_ROOT_PATH` and builds letter vocabularies automatically (a minimal loading call is sketched after this list).
- AMI: requires manual download of the AMI IHM files into `${DATASETS_ROOT_PATH}/edinburghcstr/ami` (see `ottc/tools/dataset_loader/ami.py`).
- TIMIT: requires a licensed copy. Download it from https://catalog.ldc.upenn.edu/LDC93S1, extract the archive into `${DATASETS_ROOT_PATH}/timit` (or point `TIMIT_DATASET_PATH` to the extraction root), and ensure the directory is readable by your jobs.
- Pretrained encoders: place checkpoints such as `wav2vec2-large-lv60` inside `${LARGE_MODELS_PATH}`. During training OTTC writes checkpoints and TensorBoard logs under `${LARGE_MODELS_PATH}/large_models_results/<encoder>/<suffix>/`.
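As a rough illustration of what the LibriSpeech loader boils down to, a standalone 🤗 Datasets call against the configured cache could look like the following. The `"clean"` config name and the direct use of `HUGGINGFACE_CACHE` are assumptions made for this sketch; the real loader code lives in `ottc/tools/dataset_loader/`.

```python
# Illustrative only; the repository's own loaders wrap calls like this one.
from datasets import load_dataset

from ottc.config.path_config import HUGGINGFACE_CACHE

# LibriSpeech 100h training split, cached under the configured directory.
train = load_dataset(
    "librispeech_asr",
    "clean",                 # assumed config; matches the train.100/test splits used below
    split="train.100",
    cache_dir=str(HUGGINGFACE_CACHE),
)
print(len(train), train[0]["text"][:60])
```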
Vocabulary files for each dataset live in `vocab/` and are created the first time you run an experiment.
## Running Experiments

`run_exp.py` is the main entry point for training and evaluation. The minimum command looks like:

```bash
python run_exp.py \
  --dataset librispeech_asr \
  --train_name train.100 \
  --test_name test \
  --loss_type ottc \
  --encoder_name wav2vec2-large-lv60 \
  --suffix libri_ottc_demo
```

Key arguments:
- `--dataset` selects the loader from `DATASET_LOADERS`. Available options include `librispeech_asr`, `ami`, `timit_letter`, `timit_phoneme`, `timit_fused_phoneme`, and `multilingual_librispeech_<lang>` with `<lang>` ∈ {en, fr, it, de, es}.
- `--train_name` / `--test_name` choose the split names used inside 🤗 Datasets (for LibriSpeech, e.g. `train.100`, `train.360`, `train.960`, `test`, `validation`).
- `--loss_type {ctc,ottc,softdtw,softdtweuc,...}` toggles the output head; OTTC-specific options are controlled by `--bins_metric {crossentropy,euclidian,softeuclidian,jsd}` and `--forward_type {regular,alternate_scheme,alternate_scheme2,regfrozen}`.
- `--prepare_dataset` points to a function in `ottc/utils/prepare_dataset.py`; common choices are `prepare_dataset`, `prepare_dataset_border`, or `prepare_dataset_breakdoubleletter` for breaking double letters when using OTTC (see the sketch after this list).
- `--frozen_epochs_for_regfrozen` sets the number of epochs the alignment stays frozen when using the `regfrozen` forward pass.
- `--continue_if_exist true` resumes from the most recent checkpoint; combine with `--only_evaluate true` to compute WER without further training.
- `--batch_size`, `--learning_rate`, `--num_train_epochs`, `--weight_decay`, and `--dataloader_num_workers` behave as in Hugging Face `TrainingArguments`.
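The sketch below shows one plausible reading of "breaking double letters": inserting a separator token between consecutive identical characters so a monotonic alignment never has to emit the same label twice in a row. The function name, separator symbol, and exact behavior of `prepare_dataset_breakdoubleletter` are assumptions; check `ottc/utils/prepare_dataset.py` for the real preprocessing.

```python
# Hypothetical illustration of double-letter breaking; the repository's
# prepare_dataset_breakdoubleletter may implement this differently.
from typing import List

def break_double_letters(text: str, sep: str = "2") -> List[str]:
    """Insert a separator token between consecutive identical characters,
    e.g. 'hello' -> ['h', 'e', 'l', '2', 'l', 'o']."""
    tokens: List[str] = []
    for ch in text:
        if tokens and tokens[-1] == ch:
            tokens.append(sep)
        tokens.append(ch)
    return tokens

print(break_double_letters("hello"))   # ['h', 'e', 'l', '2', 'l', 'o']
print(break_double_letters("letter"))  # ['l', 'e', 't', '2', 't', 'e', 'r']
```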
Example workflows:

```bash
# CTC baseline on LibriSpeech 100h
python run_exp.py \
  --dataset librispeech_asr \
  --train_name train.100 \
  --test_name test \
  --loss_type ctc \
  --encoder_name wav2vec2-large-lv60 \
  --suffix libri_ctc_demo

# OTTC with reg-frozen alignment on TIMIT fused phonemes
python run_exp.py \
  --dataset timit_fused_phoneme \
  --train_name train \
  --test_name kaldi_test \
  --loss_type ottc \
  --prepare_dataset prepare_dataset_for_timit_phoneme \
  --forward_type regfrozen \
  --frozen_epochs_for_regfrozen 30 \
  --suffix timit_ottc_fused

# Evaluate the latest checkpoint only
python run_exp.py --only_evaluate true --continue_if_exist true --suffix timit_ottc_fused
```

Inspect all options with `python run_exp.py --help`.
## Monitoring and Outputs

- Checkpoints: stored in `${LARGE_MODELS_PATH}/large_models_results/<encoder>/<suffix>/`.
- TensorBoard: logs are written alongside checkpoints; launch `tensorboard --logdir` on the same directory to monitor training and alignment plots.
- WER computation: handled by `ottc/utils/eval_metrics.py` using the 🤗 `wer` metric during evaluation (a standalone sketch follows this list).
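Outside of training, the same metric can be checked on a handful of transcripts with the `evaluate` package. This is a generic sketch of the 🤗 `wer` metric, not the code in `ottc/utils/eval_metrics.py`.

```python
# Standalone WER check with the Hugging Face `evaluate` package.
import evaluate

wer_metric = evaluate.load("wer")

predictions = ["the cat sat on the mat", "optimal transport aligns sequences"]
references = ["the cat sat on a mat", "optimal transport aligns sequences"]

# WER = (substitutions + insertions + deletions) / number of reference words.
print(f"WER: {wer_metric.compute(predictions=predictions, references=references):.3f}")
```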
## Reproducing the Paper

To reproduce the LibriSpeech WER experiment reported in the paper once the OTTC model is trained, run:

```bash
python run_exp.py \
  --dataset librispeech_asr \
  --train_name train.100 \
  --test_name test \
  --loss_type ottc \
  --suffix libri_ottc_100h \
  --only_evaluate true \
  --continue_if_exist true
```

To reproduce experiments reported in the OTTC paper, adapt the following templates to your local paths:
```bash
# LibriSpeech (100h, CTC)
python run_exp.py \
  --dataset librispeech_asr \
  --train_name train.100 \
  --test_name test \
  --loss_type ctc \
  --prepare_dataset prepare_dataset \
  --batch_size 64 \
  --learning_rate 5e-5 \
  --num_train_epochs 40 \
  --force_lowercase False \
  --suffix runexp_ctc_libri_100h
```
```bash
# LibriSpeech (100h, OTTC)
python run_exp.py \
  --dataset librispeech_asr \
  --train_name train.100 \
  --test_name test \
  --loss_type ottc \
  --forward_type regfrozen \
  --frozen_epochs_for_regfrozen 30 \
  --prepare_dataset prepare_dataset_breakdoubleletter \
  --batch_size 64 \
  --learning_rate 2e-4 \
  --num_train_epochs 40 \
  --suffix runexp_ottc_libri_100h_rf
```
```bash
# TIMIT (fused phonemes, CTC)
python run_exp.py \
  --dataset timit_fused_phoneme \
  --train_name train \
  --test_name kaldi_test \
  --loss_type ctc \
  --prepare_dataset prepare_dataset_for_timit_phoneme \
  --batch_size 64 \
  --learning_rate 5e-5 \
  --num_train_epochs 40 \
  --force_lowercase True \
  --suffix runexp_ctc_fuse_phoneme
```
```bash
# TIMIT (fused phonemes, OTTC)
python run_exp.py \
  --dataset timit_fused_phoneme \
  --train_name train \
  --test_name kaldi_test \
  --loss_type ottc \
  --forward_type regfrozen \
  --frozen_epochs_for_regfrozen 30 \
  --prepare_dataset prepare_dataset_for_timit_phoneme \
  --batch_size 64 \
  --learning_rate 2e-4 \
  --num_train_epochs 40 \
  --suffix runexp_ottc_fuse_phoneme
```

AMI reproduction currently depends on an internal tokenizer variant and is not bundled in this release. Support will be added in a future update, alongside alignment-metric recipes.
## Repository Layout

```
OTTC/
├── run_exp.py                    # Main training / evaluation script
├── run_ottc.py                   # Legacy research playground (older API)
├── ottc/
│   ├── config/path_config.py     # Global paths for datasets, cache, pretrained models
│   ├── models/sequence/          # Wav2Vec2 heads for CTC and OTTC
│   ├── ops_loss/                 # Optimal transport + soft-DTW losses
│   ├── tools/                    # Dataset loaders, collators, vocab builder
│   └── utils/                    # Metrics, dataset preprocessing helpers
├── vocab/                        # Auto-generated vocabularies per dataset
├── requirements.txt              # Python dependency list
└── xfs.yaml                      # Conda environment specification (CUDA 12.1, PyTorch 2.2)
```
## Citation

If you use OTTC in your work, please cite:

```bibtex
@inproceedings{kaloga2025ottc,
  title         = {A Differentiable Alignment Framework for Sequence-to-Sequence Modeling via Optimal Transport},
  author        = {Kaloga, Yacouba and
                   Kumar, Shashi and
                   Kodrasi, Ina and
                   Motlicek, Petr},
  year          = {2025},
  eprint        = {2502.01588},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CL}
}
```
## License

This project is released under the MIT License. See the LICENSE file for details. Third-party libraries retain their respective licenses (Transformers: Apache 2.0, torchaudio: BSD, POT: MIT).