This project implements knowledge distillation on top of kuangliu's pytorch-cifar repository.
Using the repository's collection of classic models, it trains student models via knowledge distillation, analyzes each model's classification performance on CIFAR-10, and evaluates the memory savings that distillation provides.
- Python 3.6+
- PyTorch 1.0+
```
# Start training with:
python main.py

# You can manually resume the training with:
python main.py --resume --lr=0.01
```
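For reference, `--resume` restores the model weights, best accuracy, and starting epoch from a saved checkpoint. The sketch below follows the upstream pytorch-cifar convention; the checkpoint path and the `net`/`acc`/`epoch` keys are assumptions and may differ in this repository.

```python
import torch

def maybe_resume(net, resume, ckpt_path='./checkpoint/ckpt.pth'):
    """Minimal sketch of checkpoint resuming (path and keys are assumptions)."""
    best_acc, start_epoch = 0.0, 0
    if resume:
        checkpoint = torch.load(ckpt_path)
        net.load_state_dict(checkpoint['net'])  # restore model weights
        best_acc = checkpoint['acc']            # best test accuracy so far
        start_epoch = checkpoint['epoch'] + 1   # continue from the next epoch
    return best_acc, start_epoch
```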
| Model | Acc. |
|---|---|
| VGG16 | 92.64% |
| ResNet18 | 93.02% |
| ResNet50 | 93.62% |
| ResNet101 | 93.75% |
| RegNetX_200MF | 94.24% |
| RegNetY_400MF | 94.29% |
| MobileNetV2 | 94.43% |
| ResNeXt29(32x4d) | 94.73% |
| ResNeXt29(2x64d) | 94.82% |
| SimpleDLA | 94.89% |
| DenseNet121 | 95.04% |
| PreActResNet18 | 95.11% |
| DPN92 | 95.16% |
| DLA | 95.47% |
This project supports DLA → MobileNetV2 knowledge distillation; the training script is `distill_dla_mobilenetv2.py`.
```
python distill_dla_mobilenetv2.py
```

Arguments:
- `--lr`: learning rate (default 0.05)
- `--epochs`: number of training epochs (default 200)
- `--alpha`: hard-loss weight (default 0.7)
- `--temp`: distillation temperature (default 5.0)
- `--batch_size`: batch size (default 128)
- `--resume`: resume training from the latest checkpoint
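The `--alpha` and `--temp` flags control the usual two-part distillation objective: a hard cross-entropy loss against the ground-truth labels and a soft KL loss against the temperature-softened teacher outputs. The sketch below is the standard Hinton-style formulation, not code copied from `distill_dla_mobilenetv2.py`, so the exact weighting in the script may differ.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.7, temp=5.0):
    # Hard loss: plain cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    # Soft loss: KL divergence between temperature-softened distributions,
    # scaled by temp**2 to keep gradient magnitudes comparable.
    soft = F.kl_div(
        F.log_softmax(student_logits / temp, dim=1),
        F.softmax(teacher_logits / temp, dim=1),
        reduction='batchmean',
    ) * (temp ** 2)
    # alpha weights the hard loss; the remainder goes to the soft loss.
    return alpha * hard + (1 - alpha) * soft
```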
Example training command:

```
python distill_dla_mobilenetv2.py --lr 0.01 --epochs 100 --alpha 0.5 --temp 4.0 --batch_size 64
```

- The script automatically loads `./checkpoint/dla.pth` as the teacher model weights.
- The best student model is saved to `./checkpoint/mobilenetv2_distilled.pth` during training.
- The latest checkpoint is saved to `./checkpoint/mobilenetv2_latest.pth` after every epoch; pass `--resume` to continue training from it (see the sketch after this list).
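A minimal sketch of the checkpointing behaviour described above, assuming a simple `net`/`acc`/`epoch` state dictionary (the field names are illustrative, not taken from the script):

```python
import torch

def save_checkpoints(student, acc, epoch, best_acc):
    state = {'net': student.state_dict(), 'acc': acc, 'epoch': epoch}
    # Latest checkpoint, overwritten every epoch; this is what --resume reloads.
    torch.save(state, './checkpoint/mobilenetv2_latest.pth')
    # Best checkpoint, updated only when the test accuracy improves.
    if acc > best_acc:
        torch.save(state, './checkpoint/mobilenetv2_distilled.pth')
        best_acc = acc
    return best_acc
```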
After training finishes, the terminal prints the best accuracy achieved, and all model weights can be found under the `checkpoint` folder.
To customize the data path or model architecture, adjust the parameter settings section of the script.