Skip to content

September11111111111/Handwritten-Digit-Recognition

Repository files navigation

基于卷积神经网络的手写数字识别

本项目实现了一个基于 PyTorch 框架的卷积神经网络(CNN)模型,用于 MNIST 数据集的手写数字识别任务,并对比了逻辑回归、决策树、支持向量机三种传统机器学习模型的分类性能。

项目结构

├── cnn_mnist_pytorch.py         # 主CNN模型训练与测试代码
├── traditional_models.py        # 三种传统方法(LR, DT, SVM)对比实验
├── cnn_lr_optimizer_test.py     # 学习率与优化器对CNN影响的实验脚本
├── modelpara.pth                # 已训练好的CNN模型权重文件
├── README.md                    # 本说明文档
└── requirements.txt             # 项目所需依赖库列表

项目功能

  • 基于 PyTorch 实现的 CNN 手写数字识别模型;
  • 实现并对比了逻辑回归、决策树、支持向量机三种方法;
  • 实验学习率(lr)与优化器(Adam/SGD/RMSProp)对CNN性能的影响;
  • 使用 MNIST 数据集,输出准确率、查准率、查全率和 F1 值等指标。

环境依赖

请使用 Python 3.8+ 环境,并通过以下命令安装依赖:

pip install -r requirements.txt

典型依赖库包括但不限于:

  • torch
  • torchvision
  • scikit-learn
  • matplotlib
  • numpy

快速运行

1. 训练 CNN 模型

python cnn_mnist_pytorch.py

模型将自动训练,并输出测试集准确率。训练完成后将保存模型权重文件 modelpara.pth

2. 对比传统分类器性能(逻辑回归、决策树、SVM)

python traditional_models.py

3. 超参数影响实验(学习率 + 优化器)(其实是懒得改超参写的脚本)

python cnn_lr_optimizer_test.py

实验结果示例

模型 准确率 (%)
CNN 99.10
SVM 96.95
Logistic 90.48
DecisionTree 88.18

About

基于卷积神经网络的手写数字识别

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages