A PyTorch-like deep learning framework in pure Rust
Volta is a minimal deep learning and automatic differentiation library for Rust, heavily inspired by PyTorch. Built from scratch to expose the internals - DAG linearization, GradFn trait dispatch, and SafeTensors interop - that production frameworks abstract away. It provides a dynamic computation graph, NumPy-style broadcasting, and common neural network primitives.
This project is an educational endeavor to demystify the inner workings of modern autograd engines. It prioritizes correctness, clarity, and a clean API over raw performance, while still providing hooks for hardware acceleration.
- Dynamic Computation Graph: Build and backpropagate through graphs on the fly, just like PyTorch.
- Reverse-Mode Autodiff: Efficient reverse-mode automatic differentiation with topological sorting.
- Rich Tensor Operations: A comprehensive set of unary, binary, reduction, and matrix operations via an ergonomic `TensorOps` trait.
- Broadcasting: Full NumPy-style broadcasting support for arithmetic operations.
- Neural Network Layers: `Linear`, `Conv2d`, `ConvTranspose2d`, `MaxPool2d`, `Embedding`, `LSTMCell`, `PixelShuffle`, `LayerNorm`, `Flatten`, `ReLU`, `GELU`, `Sigmoid`, `Tanh`, `Dropout`, `BatchNorm1d`, `BatchNorm2d`.
- Optimizers: `SGD` (momentum + weight decay), `Adam` (bias-corrected + weight decay), and experimental `Muon`.
- External Model Loading: Load weights from PyTorch, HuggingFace, and other frameworks via `StateDictMapper` with automatic weight transposition and key remapping. Supports the SafeTensors format.
- Named Layers: Human-readable state dict keys with the `Sequential::builder()` pattern for robust serialization.
- Multi-dtype Support: Initial support for f16, bf16, f32, f64, i32, i64, u8, and bool tensors.
- IO System: Save and load model weights (state dicts) via `bincode` or the SafeTensors format.
- BLAS Acceleration (macOS): Optional acceleration for matrix multiplication via Apple's Accelerate framework.
- GPU Acceleration: Experimental WGPU-based GPU support for core tensor operations (elementwise, matmul, reductions, movement ops) with automatic backward pass on GPU.
- Validation-Focused: Includes a robust numerical gradient checker to ensure the correctness of all implemented operations.
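The NumPy-style broadcasting rule mentioned above can be sketched independently of Volta's internals. This is an illustrative stand-alone function, not part of Volta's API: shapes are aligned from the right, and each dimension pair must either match or contain a 1.

```rust
// Compute the broadcast shape of two tensors per NumPy rules:
// align shapes from the right; each dimension pair must be equal,
// or one side must be 1 (which stretches to match the other).
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = vec![0; n];
    for i in 0..n {
        // Missing leading dimensions are treated as 1.
        let da = if i < n - a.len() { 1 } else { a[i - (n - a.len())] };
        let db = if i < n - b.len() { 1 } else { b[i - (n - b.len())] };
        out[i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None, // incompatible shapes
        };
    }
    Some(out)
}

fn main() {
    // [4, 1, 3] broadcast with [2, 3] -> [4, 2, 3]
    assert_eq!(broadcast_shape(&[4, 1, 3], &[2, 3]), Some(vec![4, 2, 3]));
    // [3] and [4] are incompatible
    assert_eq!(broadcast_shape(&[3], &[4]), None);
    println!("ok");
}
```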
This library is functional for training MLPs, CNNs, RNNs, GANs, VAEs, and other architectures on CPU. It features a verified autograd engine and correctly implemented im2col convolutions.
✅ What's Working:
- Core Autograd: All operations verified with numerical gradient checking
- Layers: Linear, Conv2d, ConvTranspose2d, MaxPool2d, Embedding, LSTMCell, PixelShuffle, LayerNorm, BatchNorm1d/2d, Dropout
- Optimizers: SGD (with momentum), Adam, Muon
- External Loading: PyTorch/HuggingFace model weights via SafeTensors with automatic transposition
- Named Layers: Robust serialization with human-readable state dict keys
- Loss Functions: MSE, Cross-Entropy, NLL, BCE, KL Divergence
- Examples: MNIST, CIFAR, character LM, VAE, DCGAN, super-resolution, LSTM time series
- GPU Training Pipeline: Fully GPU-accelerated training for Conv2d layers with device-aware constructors and GPU optimizer state storage
- Benchmarking Suite: Comprehensive Criterion benchmarks with 3 categories (tensor_ops, neural_networks, gpu_comparison) and HTML reports
- Enhanced GPU Safety: GPU buffer pooling, command queue throttling, CPU cache invalidation, and early warning system
- Code Quality: All `indexing_slicing` clippy errors resolved; ~400+ pedantic lints reduced to ~223 remaining
- GPU Convolution: Fully GPU-accelerated Conv2d with Direct, im2col, and iGEMM algorithms
  - All algorithms support both forward and backward passes on GPU
  - Auto-selection chooses the optimal algorithm based on input size and device
  - Memory-efficient alternatives (Direct, iGEMM) prevent OOM on large inputs
⚠️ What's in Progress:
- Performance: Comprehensive benchmarking suite for performance tracking with `just bench` commands
- Benchmarks: Known issue: the full benchmark suite crashes (suspected resource exhaustion or a synchronization barrier); isolated benchmarks pass
- GPU Support: Experimental WGPU-based acceleration via the `gpu` feature:
  - ✅ Core ops on GPU: elementwise (unary/binary), matmul, reductions (sum/max/mean), movement ops (permute/expand/pad/shrink/stride)
  - ✅ GPU backward pass for autograd with lazy CPU↔GPU transfers
  - ⚠️ Broadcasting preprocessing happens on CPU before GPU dispatch
❌ What's Missing:
- Production-ready GPU integration, distributed training, learning-rate schedulers, full attention mechanisms (LayerNorm and supporting components are implemented)
Add Volta to your Cargo.toml:

```toml
[dependencies]
volta = "0.3.0"
```

For a significant performance boost in matrix multiplication on macOS, enable the `accelerate` feature:

```toml
[dependencies]
volta = { version = "0.3.0", features = ["accelerate"] }
```

For experimental GPU acceleration via WGPU, enable the `gpu` feature:

```toml
[dependencies]
volta = { version = "0.3.0", features = ["gpu"] }
```

Or combine both for maximum performance:

```toml
[dependencies]
volta = { version = "0.3.0", features = ["accelerate", "gpu"] }
```

Volta is built around a dynamic computation graph that tracks operations for reverse-mode automatic differentiation.
When operations are performed on tensors requiring gradients, Volta constructs a Directed Acyclic Graph (DAG) of the computation. During the backward pass, this graph is linearized via topological sorting to compute gradients efficiently without deep recursion.
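The linearization step can be sketched on its own. The following is a simplified, self-contained illustration (plain node IDs rather than Volta's actual tensor type) of an iterative post-order topological sort, the approach that avoids deep recursion:

```rust
use std::collections::HashSet;

// Iterative post-order (topological) linearization of a computation DAG.
// `parents[i]` lists the inputs of node i; the backward pass visits the
// returned order in reverse, so each node's gradient is ready before
// its inputs' gradients are computed.
fn toposort(root: usize, parents: &[Vec<usize>]) -> Vec<usize> {
    let mut order = Vec::new();
    let mut visited = HashSet::new();
    // Explicit stack instead of recursion: (node, inputs_already_pushed)
    let mut stack = vec![(root, false)];
    while let Some((node, processed)) = stack.pop() {
        if processed {
            order.push(node); // all inputs have been emitted first
        } else if visited.insert(node) {
            stack.push((node, true));
            for &p in &parents[node] {
                stack.push((p, false));
            }
        }
    }
    order
}

fn main() {
    // 0: x, 1: w, 2: Linear(x, w), 3: ReLU(2), 4: Loss(3)
    let parents = vec![vec![], vec![], vec![0, 1], vec![2], vec![3]];
    let order = toposort(4, &parents);
    // Leaves come first, the loss node last; backward iterates in reverse.
    assert_eq!(*order.last().unwrap(), 4);
    println!("{:?}", order);
}
```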
```text
[Input x] (requires_grad=true)   [Weights w] (requires_grad=true)
         \                        /
          v                      v
          [Linear] <--- (GradFn: LinearBackward)
             |
             v
          [ReLU]   <--- (GradFn: ReLUBackward)
             |
             v
          [Loss]   <--- (GradFn: MSEBackward)
```
The core of the autograd engine is the Backward trait. Each differentiable operation implements this trait. During the backward pass, the topological sort yields a linearized sequence of Backward closures. The engine invokes them iteratively, dispatching the gradient of the output to compute the gradients with respect to the inputs. This iterative approach avoids call stack exhaustion on deep, complex graphs.
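This trait-object dispatch pattern can be illustrated with a minimal sketch. The types below are illustrative stand-ins, not Volta's actual `Backward` implementations, but they show the shape of the mechanism: each op stores what it needs from the forward pass and maps an output gradient to input gradients.

```rust
// A minimal sketch of trait-object gradient dispatch, mirroring the
// design described above; these types are illustrative, not Volta's.
trait Backward {
    // Given the gradient of the output, produce gradients for each input.
    fn backward(&self, grad_out: f32) -> Vec<f32>;
}

// Multiplication saves both operands from the forward pass.
struct MulBackward {
    lhs: f32,
    rhs: f32,
}

impl Backward for MulBackward {
    fn backward(&self, grad_out: f32) -> Vec<f32> {
        // d(l*r)/dl = r, d(l*r)/dr = l (chain rule applied to grad_out)
        vec![grad_out * self.rhs, grad_out * self.lhs]
    }
}

fn main() {
    // The engine walks the linearized graph and invokes each node's
    // Backward implementation iteratively - no recursion needed.
    let node: Box<dyn Backward> = Box::new(MulBackward { lhs: 3.0, rhs: 4.0 });
    let grads = node.backward(1.0);
    assert_eq!(grads, vec![4.0, 3.0]);
    println!("{:?}", grads);
}
```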
Volta defaults to HuggingFace's SafeTensors format for model weight serialization, rather than Pickle or generic binary serializers like bincode, for several reasons:
- Zero-copy loading: SafeTensors allows memory-mapping (mmap) weights directly from disk, bypassing memory allocations during loading.
- Safety: Unlike Pickle, SafeTensors cannot execute arbitrary code.
- Interoperability: It enables Volta to natively read weights from models trained in PyTorch or JAX.
Volta provides multiple backends for computation:
- CPU Backend: Matrix multiplication accelerated by Apple's Accelerate framework (via the `accelerate` feature) and multi-threaded CPU matrix operations using `matrixmultiply`.
- GPU Backend (Experimental): Uses `wgpu` to dispatch operations to the GPU. Tensors can move seamlessly between devices using `Tensor::to_device()`.
Benchmarking is in progress; contributions welcome. The CPU matmul baseline is tracked against the naive implementation via Criterion.
Here's how to define a simple Multi-Layer Perceptron (MLP) with named layers, train it on synthetic data, and save the model.
```rust
use volta::{
    Adam, Sequential, TensorOps, io,
    nn::{Linear, Module, ReLU},
    tensor::{RawTensor, mse_loss},
};

fn main() {
    // 1. Define a simple model with named layers: 2 -> 8 -> 1
    let model = Sequential::builder()
        .add_named("fc1", Box::new(Linear::new(2, 8, true)))
        .add_unnamed(Box::new(ReLU))
        .add_named("fc2", Box::new(Linear::new(8, 1, true)))
        .build();

    // 2. Create synthetic data
    let x_data = vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0];
    let x = RawTensor::new(x_data, &[4, 2], false); // Batch size 4, 2 features
    let y_data = vec![0.0, 1.0, 1.0, 0.0];
    let y = RawTensor::new(y_data, &[4], false); // Flattened targets

    // 3. Set up the optimizer
    let params = model.parameters();
    let mut optimizer = Adam::new(params, 0.1, (0.9, 0.999), 1e-8, 0.0);

    // 4. Training loop
    println!("Training a simple MLP to learn XOR...");
    for epoch in 0..=300 {
        optimizer.zero_grad();
        let pred = model.forward(&x).reshape(&[4]); // flatten to match targets
        let loss = mse_loss(&pred, &y);
        if epoch % 20 == 0 {
            println!(
                "Epoch {}: loss = {:.6}",
                epoch,
                loss.borrow().data.first().copied().unwrap_or(f32::NAN)
            );
        }
        loss.backward();
        optimizer.step();
    }

    // 5. Save and load the state dict (human-readable keys: "fc1.weight", "fc1.bias", etc.)
    let state = model.state_dict();
    io::save_state_dict(&state, "model.bin").expect("Failed to save");

    // Verify loading
    let mut new_model = Sequential::builder()
        .add_named("fc1", Box::new(Linear::new(2, 8, true)))
        .add_unnamed(Box::new(ReLU))
        .add_named("fc2", Box::new(Linear::new(8, 1, true)))
        .build();
    let loaded_state = io::load_state_dict("model.bin").expect("Failed to load");
    new_model.load_state_dict(&loaded_state);
}
```

The following uses the current API to define a training-ready CNN.
```rust
use volta::{Sequential, Conv2d, MaxPool2d, Flatten, Linear, ReLU, Adam};
use volta::nn::Module;
use volta::TensorOps;

fn main() {
    // 1. Define the model
    let model = Sequential::new(vec![
        // Input: 1x28x28
        Box::new(Conv2d::new(1, 6, 5, 1, 2, true)), // padding 2
        Box::new(ReLU),
        Box::new(MaxPool2d::new(2, 2, 0)),
        // Feature map size here: 6x14x14
        Box::new(Flatten::new()),
        Box::new(Linear::new(6 * 14 * 14, 10, true)),
    ]);

    // 2. Data & optimizer
    let input = volta::randn(&[4, 1, 28, 28]); // Batch 4
    let target = volta::randn(&[4, 10]); // Dummy targets
    let params = model.parameters();
    let mut optim = Adam::new(params, 1e-3, (0.9, 0.999), 1e-8, 0.0);

    // 3. Training step
    optim.zero_grad();
    let output = model.forward(&input);
    let loss = volta::mse_loss(&output, &target);
    loss.backward();
    optim.step();
    println!("Loss: {:?}", loss.borrow().data.first().copied().unwrap_or(f32::NAN));
}
```

Volta can load weights from PyTorch, HuggingFace, and other frameworks using the SafeTensors format, with automatic weight mapping and transposition.
```rust
use volta::{
    Linear, Module, ReLU, Sequential,
    io::{load_safetensors, mapping::StateDictMapper},
};

fn main() {
    // 1. Build a matching architecture with named layers
    let mut model = Sequential::builder()
        .add_named("fc1", Box::new(Linear::new(784, 128, true)))
        .add_unnamed(Box::new(ReLU))
        .add_named("fc2", Box::new(Linear::new(128, 10, true)))
        .build();

    // 2. Load PyTorch weights with automatic transposition.
    //    PyTorch Linear stores weights as [out, in]; Volta uses [in, out].
    let pytorch_state = load_safetensors("pytorch_model.safetensors")
        .expect("Failed to load SafeTensors");
    let mapper = StateDictMapper::new()
        .transpose("fc1.weight") // [128, 784] → [784, 128]
        .transpose("fc2.weight"); // [10, 128] → [128, 10]
    let volta_state = mapper.map(pytorch_state);

    // 3. Load into the model
    model.load_state_dict(&volta_state);

    // 4. Run inference
    let input = volta::randn(&[1, 784]);
    let output = model.forward(&input);
    println!("Output shape: {:?}", output.borrow().shape);
}
```

Weight Mapping Features:
- `rename(from, to)` - Rename individual keys
- `rename_prefix(old, new)` - Rename all keys with a given prefix
- `strip_prefix(prefix)` - Remove a prefix from keys
- `transpose(key)` - Transpose 2D weight matrices (PyTorch compatibility)
- `transpose_pattern(pattern)` - Transpose all matching keys
- `select_keys(keys)` / `exclude_keys(keys)` - Filter the state dict
See `examples/load_external_mnist.rs` for a complete end-to-end example with validation.
```rust
use volta::{Device, TensorOps, randn};

fn main() {
    // Create tensors on the CPU
    let a = randn(&[1024, 1024]);
    let b = randn(&[1024, 1024]);

    // Move them to the GPU
    let device = Device::gpu().expect("GPU required");
    let a_gpu = a.to_device(device.clone());
    let b_gpu = b.to_device(device.clone());

    // Operations execute on the GPU automatically
    let c_gpu = a_gpu.matmul(&b_gpu); // GPU matmul
    let sum_gpu = c_gpu.sum(); // GPU reduction

    // Gradients are computed on the GPU when possible
    sum_gpu.backward();
    println!("Gradient shape: {:?}", a_gpu.borrow().grad.as_ref().unwrap().shape());
}
```

The library is designed around a few core concepts:
- `Tensor`: The central data structure, an `Rc<RefCell<RawTensor>>`, which holds data, shape, gradient information, and device location. Supports multiple data types (f32, f16, bf16, f64, i32, i64, u8, bool).
- `TensorOps`: A trait implemented for `Tensor` that provides the ergonomic, user-facing API for all operations (e.g., `tensor.add(&other)`, `tensor.matmul(&weights)`).
- `nn::Module`: A trait for building neural network layers and composing them into larger models. Provides `forward()`, `parameters()`, `state_dict()`, `load_state_dict()`, and `to_device()` methods.
- `Sequential::builder()`: Builder pattern for composing layers with named parameters for robust serialization. Supports both `add_named()` for human-readable state dict keys and `add_unnamed()` for activation layers.
- Optimizers (`Adam`, `SGD`, `Muon`): Structures that take a list of model parameters and update their weights based on computed gradients during `step()`.
- `Device`: Abstraction for CPU/GPU compute. Tensors can be moved between devices with `to_device()`, and operations automatically dispatch to GPU kernels when available.
- External Model Loading: `StateDictMapper` provides transformations (rename, transpose, prefix handling) to load weights from PyTorch, HuggingFace, and other frameworks via the SafeTensors format.
- Vision Support: `Conv2d`, `ConvTranspose2d` (for GANs/VAEs), `MaxPool2d`, `PixelShuffle` (for super-resolution), `BatchNorm1d`/`BatchNorm2d`, and `Dropout`.
- Sequence Support: `Embedding` layers for discrete inputs, `LSTMCell` for recurrent architectures.
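The `Rc<RefCell<...>>` design deserves a brief illustration: several graph nodes may hold the same tensor as an input, and the backward pass must accumulate gradients into it in place. The sketch below uses a simplified stand-in for Volta's `RawTensor` (the names `RawTensorSketch` and `accumulate_grad` are illustrative, not Volta's API):

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Simplified stand-in for Volta's RawTensor: shared ownership via Rc,
// interior mutability via RefCell so backward can write gradients in place.
struct RawTensorSketch {
    data: Vec<f32>,
    grad: Option<Vec<f32>>,
}

type TensorSketch = Rc<RefCell<RawTensorSketch>>;

fn accumulate_grad(t: &TensorSketch, g: &[f32]) {
    let mut inner = t.borrow_mut();
    let grad = inner.grad.get_or_insert_with(|| vec![0.0; g.len()]);
    for (dst, src) in grad.iter_mut().zip(g) {
        *dst += src; // gradients from multiple consumers sum up
    }
}

fn main() {
    let t: TensorSketch = Rc::new(RefCell::new(RawTensorSketch {
        data: vec![1.0, 2.0],
        grad: None,
    }));
    let alias = Rc::clone(&t); // a second consumer of the same tensor
    accumulate_grad(&t, &[0.5, 0.5]);
    accumulate_grad(&alias, &[0.5, 0.5]);
    // Both contributions landed in the single shared gradient buffer.
    assert_eq!(t.borrow().grad.as_ref().unwrap(), &vec![1.0, 1.0]);
    println!("{:?}", t.borrow().data);
}
```

This shared-handle pattern is also what restricts Volta to single-threaded CPU execution (see the limitations section), since `Rc<RefCell>` is not `Send`.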
Volta has an extensive test suite that validates the correctness of every operation and its gradient. To run the tests:
```bash
cargo test -- --nocapture
```

To run tests with BLAS acceleration enabled (on macOS):

```bash
cargo test --features accelerate -- --nocapture
```

To run tests with GPU support:

```bash
cargo test --features gpu -- --nocapture
```

Run specific test categories:

```bash
cargo test core          # Core tensor tests
cargo test grad_check    # Numerical gradient validation
cargo test broadcasting  # Broadcasting rules
cargo test neural        # Neural network layers
cargo test optimizer     # Optimizer convergence
```

The `examples/` directory contains complete working examples demonstrating various capabilities:
```bash
# Basic examples
cargo run --example readme1              # Simple MLP training
cargo run --example readme2              # LeNet-style CNN
cargo run --example showcase             # Feature showcase

# Vision tasks
cargo run --example mnist_cnn            # MNIST digit classification
cargo run --example super_resolution     # Image upscaling with PixelShuffle
cargo run --example dcgan                # Deep Convolutional GAN

# Generative models
cargo run --example vae                  # Variational Autoencoder

# Sequence models
cargo run --example char_lm              # Character-level language model
cargo run --example lstm_time_series     # Time series prediction

# External model loading
cargo run --example load_external_mnist  # Load PyTorch weights via SafeTensors

# GPU acceleration
cargo run --example gpu --features gpu           # GPU tensor operations
cargo run --example gpu_training --features gpu  # GPU-accelerated training

# Regression
cargo run --example polynomial_regression  # Polynomial curve fitting
```

The next major steps for Volta are focused on expanding its capabilities to handle more complex models and improving performance.
- Complete GPU Integration: Port remaining neural network layers (Linear, Conv2d) to GPU, optimize GEMM kernels with shared memory tiling.
- Performance Optimization: Implement SIMD for element-wise operations, optimize broadcasting on GPU, kernel fusion for composite operations.
- Learning Rate Schedulers: Cosine annealing, step decay, warmup schedules.
- Conv2d Memory Inefficiency: The `im2col` implementation in `src/nn/layers/conv.rs` materializes the entire patch matrix in memory. Large batch sizes or high-resolution images can easily OOM even on high-end machines. Somewhat mitigated by algorithm selection in the convolutional layer.
- GPU Kernel Efficiency: The current GPU matmul uses a naive implementation without shared memory tiling. Significant performance gains are possible with optimized GEMM kernels.
- Multi-dtype Completeness: While storage supports multiple dtypes (f16, bf16, f64, etc.), most operations still assume f32. Full dtype support requires operation kernels for each type.
- Single-threaded: Uses `Rc<RefCell>` instead of `Arc<Mutex>`, limiting CPU execution to a single thread.
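To see why im2col OOMs, a back-of-envelope calculation helps. The sketch below (an illustrative helper, not Volta code) counts the elements of the materialized patch matrix, assuming stride 1 and "same" padding so the output spatial size equals the input:

```rust
// Back-of-envelope for the im2col blowup: the patch matrix has
// (C_in * K * K) rows by (N * H_out * W_out) columns, i.e. every
// K x K window is materialized as a full column.
fn im2col_elems(n: usize, c_in: usize, h_out: usize, w_out: usize, k: usize) -> usize {
    c_in * k * k * h_out * w_out * n
}

fn main() {
    // A batch of 64 RGB 224x224 images with a 3x3 kernel:
    let elems = im2col_elems(64, 3, 224, 224, 3);
    let bytes = elems * std::mem::size_of::<f32>();
    // ~330 MiB for the patch matrix alone - per conv layer, per forward pass.
    println!("{} elements, ~{} MiB", elems, bytes / (1024 * 1024));
}
```

The memory-efficient Direct and iGEMM algorithms mentioned above avoid materializing this matrix.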
Contributions, issues, and feature requests are welcome! Feel free to check the issues page.
This project is licensed under the MIT License - see the LICENSE file for details.