Warning
SPSN has evolved into the ParaLIF (Parallelizable Leaky-Integrate-and-Fire) neuron. ParaLIF supports a wider range of stochastic and deterministic spiking functions, and a recurrent version is also available. See https://github.com/NECOTIS/Parallelizable-Leaky-Integrate-and-Fire-Neuron
This repository contains code for simulating the proposed Stochastic Parallelizable Spiking Neuron (SPSN), which accelerates the training of spiking neural networks (SNNs). SPSN is compared with the Leaky Integrate-and-Fire (LIF) neuron on the Spiking Heidelberg Digits (SHD) dataset. The repository consists of the following key components:
- `datasets.py`: This module provides a simple interface for loading and accessing the training and test datasets.
- `network.py`: This module contains the implementation of the neural network itself, including the code for training and evaluating it.
- `run.py`: This is the main entry point for running the simulation. It provides a simple command-line interface for specifying various options.
- `datasets` directory: This directory contains the training and test datasets as HDF5 files. The SHD dataset must be downloaded into this directory from https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/
- `neurons` directory: This directory contains the implementations of the two neuron types, which extend the base class in `base.py`. The available models are:
  - `lif.py`: The Leaky Integrate-and-Fire model.
  - `spsn.py`: The Stochastic Parallelizable Spiking Neuron model. It can be simulated with the Sigmoid-Bernoulli firing mode (SPSN-SB) or with the Gumbel-Softmax firing mode (SPSN-GS).
- `outputs` directory: This directory contains the outputs generated by the simulation.
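Why SPSN can be trained faster than LIF can be illustrated with a minimal, pure-Python sketch. It rests on an assumption made for illustration: SPSN is taken to drop the LIF reset, so the membrane potential becomes a linear filter of the input current, u[t] = sum over k <= t of beta^(t-k) * I[k], which can be evaluated for all time steps at once instead of step by step. The function names, the threshold offset inside the firing functions, and the two-class Gumbel-softmax formulation are hypothetical and are not taken from `neurons/spsn.py`. Only forward passes are shown; how gradients flow through the sampling step in the actual implementation is not covered.

```python
import math
import random

# Hypothetical sketch: names and formulas are illustrative, not the repository's code.

def lif_sequential(currents, beta, threshold=1.0):
    """LIF forward pass with hard reset: step t needs the spike (and reset) of step t-1."""
    u, spikes = 0.0, []
    for i_t in currents:
        u = beta * u + i_t          # leaky integration
        s = 1 if u >= threshold else 0
        if s:
            u = 0.0                 # the reset couples consecutive time steps
        spikes.append(s)
    return spikes

def membrane_parallel(currents, beta):
    """Reset-free leaky integration: u[t] = sum_{k<=t} beta**(t-k) * I[k].

    Each u[t] depends only on the inputs, so every time step can be computed
    independently (on a GPU this would be a single convolution or matrix product).
    """
    return [
        sum(beta ** (t - k) * currents[k] for k in range(t + 1))
        for t in range(len(currents))
    ]

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def fire_sigmoid_bernoulli(potentials, threshold=1.0, rng=random):
    """Assumed SB mode: spike[t] ~ Bernoulli(sigmoid(u[t] - threshold))."""
    return [1 if rng.random() < sigmoid(u - threshold) else 0 for u in potentials]

def _gumbel(rng):
    """Sample standard Gumbel noise; the clamp avoids log(0)."""
    return -math.log(-math.log(max(rng.random(), 1e-12)))

def fire_gumbel_softmax(potentials, threshold=1.0, temperature=1.0, rng=random):
    """Assumed GS mode: two-class (spike / no spike) Gumbel-softmax relaxation.

    With logits (u - threshold, 0), the softmax probability of the spike class
    reduces to a sigmoid of the noisy logit difference divided by the temperature.
    """
    return [
        sigmoid((u - threshold + _gumbel(rng) - _gumbel(rng)) / temperature)
        for u in potentials
    ]
```

For example, with `currents = [0.6, 0.6, 0.6]`, `beta=1.0` and a threshold of 1.0, `lif_sequential` yields `[0, 1, 0]`, while thresholding `membrane_parallel` gives `[0, 1, 1]`. Under this assumption, dropping the reset changes the dynamics, and that is the price paid for evaluating all time steps in parallel.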
The `run.py` script accepts the following arguments:
- `--seed`: Random seed for reproducibility.
- `--dataset`: The dataset to use for training; currently only `heidelberg` is supported.
- `--neuron`: The neuron model to use for training; options are `LIF`, `SPSN-SB`, `SPSN-GS`, and `Non-Spiking`. The `Non-Spiking` neuron is a conventional (non-spiking) neuron followed by a ReLU activation.
- `--nb_epochs`: The number of training epochs.
- `--tau_mem`: The neuron membrane time constant.
- `--tau_syn`: The neuron synaptic current time constant.
- `--batch_size`: The batch size for training.
- `--hidden_size`: The number of neurons in the hidden layer.
- `--nb_layers`: The number of hidden layers.
- `--reg_thr`: The spiking frequency regularization threshold.
- `--loss_mode`: The mode for computing the loss; options are `last`, `max`, and `mean`.
- `--data_augmentation`: Whether to use data augmentation during training (`True` or `False`).
- `--h_shift`: The random shift factor for data augmentation.
- `--scale`: The random scale factor for data augmentation.
- `--dir`: The directory in which to save the results.
- `--save_model`: Whether to save the trained model (`True` or `False`).
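The role of `--loss_mode` and `--reg_thr` can be sketched as follows. This is a hypothetical illustration under stated assumptions: the readout layer is assumed to produce one score per class at every time step, `loss_mode` is assumed to select how those scores are collapsed over time before a cross-entropy loss, and the regularizer is assumed to penalize only the part of a hidden layer's mean firing rate that exceeds `reg_thr`. The weighting factor `reg_weight` is invented for the sketch, and the exact formulas in `network.py` may differ.

```python
import math

# Hypothetical sketch of what --loss_mode and --reg_thr could control;
# the exact formulas in network.py may differ.

def reduce_over_time(outputs, loss_mode):
    """Collapse readout scores of shape [time][class] into one score per class."""
    if loss_mode == "last":
        return list(outputs[-1])                                # final time step only
    if loss_mode == "max":
        return [max(col) for col in zip(*outputs)]              # peak score over time
    if loss_mode == "mean":
        return [sum(col) / len(col) for col in zip(*outputs)]   # average over time
    raise ValueError(f"unknown loss_mode: {loss_mode!r}")

def cross_entropy(scores, target):
    """Softmax cross-entropy for a single example (log-sum-exp trick for stability)."""
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum_exp - scores[target]

def rate_penalty(spikes, reg_thr):
    """Assumed regularizer: squared excess of a layer's mean firing rate over reg_thr.

    spikes has shape [time][neuron] with 0/1 entries.
    """
    n_time, n_neurons = len(spikes), len(spikes[0])
    rate = sum(sum(row) for row in spikes) / (n_time * n_neurons)
    return max(0.0, rate - reg_thr) ** 2

def total_loss(outputs, target, hidden_spikes, loss_mode, reg_thr, reg_weight=1.0):
    """Classification loss plus the (hypothetical) firing-rate penalty."""
    ce = cross_entropy(reduce_over_time(outputs, loss_mode), target)
    return ce + reg_weight * rate_penalty(hidden_spikes, reg_thr)
```

With a penalty of this shape, a lower `reg_thr` (such as the 0.1 used for SPSN-GS in the commands below, versus 0.4 for SPSN-SB) pushes the network toward sparser activity, while layers already firing below the threshold are left unpenalized.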
To run the code in basic mode, use the following commands:

```bash
python run.py --seed 0 --neuron 'LIF'
python run.py --seed 0 --neuron 'SPSN-SB'
python run.py --seed 0 --neuron 'SPSN-GS'
python run.py --seed 0 --neuron 'Non-Spiking'
```

To add data augmentation during training, use the following commands:
```bash
python run.py --seed 0 --neuron 'LIF' --data_augmentation True
python run.py --seed 0 --neuron 'SPSN-SB' --data_augmentation True
python run.py --seed 0 --neuron 'SPSN-GS' --data_augmentation True
python run.py --seed 0 --neuron 'Non-Spiking' --data_augmentation True
```

To reduce the spiking frequency of SPSN, regularization can be enabled with the following commands:
```bash
python run.py --seed 0 --neuron 'SPSN-SB' --data_augmentation True --reg_thr 0.4
python run.py --seed 0 --neuron 'SPSN-GS' --data_augmentation True --reg_thr 0.1
```

The results obtained with the commands listed above are summarized in the following tables:
- Classification accuracy on the test set:
| Neuron | Basic | Data augmentation | Data augmentation + Regularization |
|---|---|---|---|
| LIF | 71.37% | 83.03% | - |
| SPSN-SB | 77.16% | 86.08% | 89.70% |
| SPSN-GS | 75.66% | 86.08% | 89.39% |
| Non-Spiking | 71.82% | 66.07% | - |
- Training duration for one epoch:
| Neuron | Basic | Data augmentation | Data augmentation + Regularization |
|---|---|---|---|
| LIF | 252.2 s | 261.7 s | - |
| SPSN-SB | 5.5 s | 10.7 s | 10.7 s |
| SPSN-GS | 6.8 s | 12.2 s | 12.2 s |
| Non-Spiking | 1.5 s | 6.7 s | - |
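The speed-up of the SPSN variants over LIF follows directly from the table above: about 46x (SPSN-SB) and 37x (SPSN-GS) in the basic setting, and about 24x and 21x with data augmentation. The short script below reproduces these ratios from the reported durations; the dictionary layout and the function name are only for illustration.

```python
# Per-epoch training durations in seconds, copied from the table above.
DURATIONS = {
    "LIF": {"Basic": 252.2, "Data augmentation": 261.7},
    "SPSN-SB": {"Basic": 5.5, "Data augmentation": 10.7},
    "SPSN-GS": {"Basic": 6.8, "Data augmentation": 12.2},
    "Non-Spiking": {"Basic": 1.5, "Data augmentation": 6.7},
}

def speedup_over_lif(neuron, setting):
    """How many times faster one epoch of `neuron` is than one epoch of LIF."""
    return DURATIONS["LIF"][setting] / DURATIONS[neuron][setting]

for neuron in ("SPSN-SB", "SPSN-GS"):
    for setting in ("Basic", "Data augmentation"):
        print(f"{neuron:<8} {setting:<18} {speedup_over_lif(neuron, setting):5.1f}x")
```

Since the regularized runs take the same time per epoch as the augmented ones, their speed-ups match the "Data augmentation" column.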
The following libraries are required to run the code:
- h5py
- numpy
- torch
- torchvision
- tqdm