7步搞懂手写数字识别Mnist

大家好啊,我是董董灿。

图像识别有很多入门项目,其中Mnist 手写数字识别绝对是最受欢迎的。

该项目以数据集小、神经网络简单、任务简单为优势,并且集合了CNN网络中该有的东西,可谓麻雀虽小,五脏俱全。

非常适合新手上手学习。

本文以代码走读的形式,带你一览该项目的每一处细节。

文章末尾附代码下载链接,不用GPU, 你也可以从头训练一个神经网络出来。

什么是手写数字识别

简答来说,就是搭建了一个卷积神经网络,可以完成手写数字的识别。

我用笔在纸上写了个6,神经网络就能认识这是个6,我写了个8,它就识别出来这是个8,就这么简单。

之所以说该任务简单,是因为它的标签只有 0-9 这 10 种分类,相比于 resnet 等网络在 ImageNet 上 1000 个分类,确实小很多。

虽然简单,但背后的原理却一点都不少,典型的CNN训练和算法无一缺席。

与该项目一起出名的,便是大名鼎鼎的 MNIST(Mathematical Numbers In Text) 数据集。

该数据集中包含了 60,000 个训练图像和 10,000 个测试图像,图像都是各种手写的数字,基本都是长这样的。

7步精读代码

在简单了解了项目背景后,我以代码走读的形式,一点点介绍该神经网络。

第一步:导入必要的库

# 导入NumPy数学工具箱
import numpy as np 
# 导入Pandas数据处理工具箱
import pandas as pd
# 从 Keras中导入 mnist数据集
from keras.datasets import mnist

keras 是一个开源的人工神经网络库,里面有很多经典的神经网络和数据集,要用的 mnist 数据集就在其中。

第二步:加载数据集

(x_train, y_train), (x_test, y_test)
=  mnist.load_data() 

这条命令利用 keras 中自带的 mnist 模块,加载数据集(load_data)进来,分别赋值给四个变量。

其中:x_train 保存用来训练的图像,y_train 是与之对应的标签。假设图像中的数字是1,那么标签就是1。

x_test 和 y_test 分别为用来验证的图像和标签,也就是验证集。训练完神经网络后,可以使用验证集中的数据进行验证。

第三步:数据预处理

其中一个预处理内容是改变数据集的 shape,使其满足模型的要求。

 # 导入keras.utils工具箱的类别转换工具
from tensorflow.keras.utils import to_categorical# 给标签增加维度,使其满足模型的需要# 原始标签,比如训练集标签的维度信息是[60000, 28, 28, 1]
X_train = X_train_image.reshape(60000,28,28,1)
X_test = X_test_image.reshape(10000,28,28,1)# 特征转换为one-hot编码
y_train = to_categorical(y_train_lable, 10)
y_test = to_categorical(y_test_lable, 10)

这个数据集中的共 60000 张训练图像,10000 张验证图像,每张图像的长宽均为 28 个像素,通道数为 1。

那么对于训练集 x_train 而言,将其形状变为 NHWC = [60000, 28, 28, 1], 验证集类似。

to_categorical 的作用是将样本标签转为 one-hot 编码,而 one-hot 编码的作用是可以对于类别更好的计算概率或得分。

one-hot

之所以用 one-hot 编码,是因为对于输出 0-9 这10个标签而言,每个标签的地位应该是相等的,并不存在标签数字 2 大于数字 1 的情况。

但如果我们直接利用标签的原始值(0-9)进行最终结果的计算,就会出现标签2 大于标签 1的情况。

因此,在大部分情况下,都需要将标签转换为 one-hot 编码,也就独热编码,这样标签之间便没有任何大小而言。

这个例子中,数字 0-9 转换为的独热编码为:

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]

每一行的向量代表一个标签。

假设 [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.] 代表 0 而 [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.] 代表1,可以看到这两者之间是正交独立的,不存在谁比谁大的问题。

第四步:创建神经网络。

# 从 keras 中导入模型
from keras import models 
# 从 keras.layers 中导入神经网络需要的计算层
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
# 构建一个最基础的连续的模型,所谓连续,就是一层接着一层
model = models.Sequential()
# 第一层为一个卷积,卷积核大小为(3,3), 输出通道32,使用 relu 作为激活函数
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28,28,1)))
# 第二层为一个最大池化层,池化核为(2,2)
# 最大池化的作用,是取出池化核(2,2)范围内最大的像素点代表该区域
# 可减少数据量,降低运算量。
model.add(MaxPooling2D(pool_size=(2, 2)))
# 又经过一个(3,3)的卷积,输出通道变为64,也就是提取了64个特征。
# 同样为 relu 激活函数
model.add(Conv2D(64, (3, 3), activation='relu'))
# 上面通道数增大,运算量增大,此处再加一个最大池化,降低运算
model.add(MaxPooling2D(pool_size=(2, 2)))
# dropout 随机设置一部分神经元的权值为零,在训练时用于防止过拟合
# 这里设置25%的神经元权值为零
model.add(Dropout(0.25)) 
# 将结果展平成1维的向量
model.add(Flatten())
# 增加一个全连接层,用来进一步特征融合
model.add(Dense(128, activation='relu'))
# 再设置一个dropout层,将50%的神经元权值为零,防止过拟合
# 由于一般的神经元处于关闭状态,这样也可以加速训练
model.add(Dropout(0.5)) 
# 最后添加一个全连接+softmax激活,输出10个分类,分别对应0-9 这10个数字
model.add(Dense(10, activation='softmax'))

上面每一行代码都加了注释,说明每一行的作用,短短几行,便是这个手写数字识别神经网络的全部了。

第五步:训练

# 编译上述构建好的神经网络模型
# 指定优化器为 rmsprop
# 制定损失函数为交叉熵损失
model.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])
# 开始训练              
model.fit(X_train, y_train, # 指定训练特征集和训练标签集validation_split = 0.3, # 部分训练集数据拆分成验证集epochs=5, # 训练轮次为5轮batch_size=128) # 以128为批量进行训练

Epoch 5/5
329/329 [==============================] - 15s 46ms/step - loss: 0.1054 - accuracy: 0.9718 - val_loss: 0.0681 - val_accuracy: 0.9826
训练结果如上,可以看到最后的训练精度达到了98.26%,还是挺高的。

第6步:验证集验证

# 在测试集上进行模型评估
score = model.evaluate(X_test, y_test) 
print('测试集预测准确率:', score[1]) # 打印测试集上的预测准确率

313/313 [==============================] - 1s 4ms/step - loss: 0.0662 - accuracy: 0.9815 测试集预测准确率: 0.9815000295639038

可以看到在验证集上也能有98%的准确率。

第7步:验证一张图片

# 预测验证集第一个数据
pred = model.predict(X_test[0].reshape(1, 28, 28, 1)) 
# 把one-hot码转换为数字
print(pred[0],"转换一下格式得到:",pred.argmax())# 导入绘图工具包
import matplotlib.pyplot as plt
# 输出这个图片
plt.imshow(X_test[0].reshape(28, 28),cmap='Greys')

以验证集中的第一张图片为例来进行验证。

1/1 [==============================] - 0s 17ms/step
[4.2905590e-15 2.6790809e-11 2.8249305e-09 2.3393848e-11 7.1304548e-14
1.8217797e-18 5.7493907e-19 1.0000000e+00 8.0317367e-15 4.6352322e-10]

转换一下格式得到:7

得到的数字是7,将该图片显示出来,确实是7。说明训练的模型确实达到了识别数字的水平。

总结

手写数字识别项目比较简单,仅仅两个卷积层,整体运算量不大,就目前计算机的配置,即使笔记本基本上都可以完成该神经网络的训练和验证。

如果你感兴趣,关注公众号《董董灿是个攻城狮》后台回复【mnist】获取源码,实操起来吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/250934.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

二开项目权限应用全流程-按钮级控制

二开项目权限应用全流程-按钮级控制 员工A和员工B都可以访问同一个页面(以员工管理为例),但是员工A可以导出excel,员工B就不可以导出excel(看不到按钮) 思路 用户登陆成功后,用户可以访问的按钮级别权限保存在point…

VISIO使用技巧汇总

0.连接线拐弯或者连接不合适 0-0.Goal ​​​​​​​ 0-1. Automatic connection 0-3.Resolvent 0-3-0.ALTF9选项,取消粘附位置调整 0-3-1.选中线段-选中中心点-shift增加直角调整合适位置

Microsoft Visio 直线连接线

Microsoft Visio 直线连接线 1. 连接线 2. 直线连接线 3. 直线连接线图 References https://yongqiang.blog.csdn.net/

visio画太极图

步骤一 添加两个圆,且大圆的半径是小圆的2倍。 步骤二 往小圆添加一条直线作为直径 步骤三 选中小圆和直径,依次点击开发工具–操作–连接,然后选中连接后的小圆,再依次点击开发工具–操作–修建,可以分离出如下所示的两个…

visio 2007 画直线和矩形

visio 2007 画直线和矩形 1.问题描述 在一些图形中如果直接用连接线,会直接连到一些不理想的位置,而2007中不像2013及其以后那些版本中,有侧边栏能够直接画直线。 2.解决方式 直接选择工具栏中的红圈中的图标 能够生成红圈中的工具栏 然…

Visio对mysql怎么画er图_怎么用Visio画ER图

展开全部 画法如下: 1、由于Visio 2003默认的绘图模板并没有32313133353236313431303231363533e4b893e5b19e31333339653661E-R图这一项,但是画E-R图必须的基本图形Visio 2003还是有的,所以就得先把必要的图形添加到“我的模板”。以添加椭圆和…

Visio2010中设置线为直线

Visio2010中设置线为直线 在Visio2010中默认的线不是直线而是曲线,在画图中需要使用直线时要进行设置,下面介绍Visio2010中设置直线的方法。 1、打开Visio2010,然后点击设计: 2、点击调整大小下面的三角: 3、进入页面…

visio绘制流程图连接线总拐弯

描述 如图所示绘制流程图的连接线总拐弯 很让我强迫症发作 可以看到垂直的连接线总是会自动拐个弯 相关技巧 有说连接线中间点可以控制和增加中间点 或者按住shift 进行调整 这个还没研究明白咋操作不过没解决本质问题 此外还可以右键修改连接线属性 还可以在设计中进行调…

visio插入箭头_visio流程图中画箭头

visio流程图中画箭头 随着社会和经济的发展,电脑visio 2019软件已经成为我们生活中必不可少的一部分。visio 2019软件常常被我们使用于流程图的制作,很多第一次接触的朋友们不知道怎么在visio 2019软件制作流程图,接下来就让小编来教你们吧。 具体如下: 1. 第一步,打开电脑…

visio绘图小技巧

1.如何在图框的任意位置添加点? 先选中x点指令,再按住ctrl键,即可在任意位置画点 2.如何画出锯齿形线段? visio里面好像没有现成的锯齿形线段,所以可以利用直线反复折画,但是这里有个小技巧,就…

Visio简单画图使用方法

Visio使用方法 相信有很多初学者跟我一样,只会使用Word进行简单的画图。本章主要讲述如何使用Visio来画图(版本为2010) 一、系统流程图、数据流程图、ER图画法 1.打开Visio软件,创建简单模板 2.根据需求点击左侧"基本流程…

设计模式之~外观模式

定义: 为子系统中的一组接口提供一个一致的界面,此模式定义了一个高层接口,这个接口使得这一子系统更加容易使用。 结构图: 区分中介模式: 门面模式对外提供一个接口 中介模式对内提供一个接口 优点: 松耦…

【xv6操作系统】安装、运行与调试

一、构建、装入过程 1.编写“启动代码主体代码”(在下载的xv6的原始代码上进行修改) 2.源代码进行编译、链接生成系统镜像(elf格式的目标文件) 3.将系统镜像保存起来(如保存到磁盘、flash或者网络服务器上&#xff…

spring入门(面试题)

Spring框架的核心:IoC容器和AOP模块。通过IoC容器管理POJO对象以及他们之间的耦合关系;通过AOP以动态非侵入的方式增强服务。 IoC让相互协作的组件保持松散的耦合,而AOP编程允许你把遍布于应用各层的功能分离出来形成可重用的功能组件。 IO…

spring 入门

Spring 是什么(1) Spring 是一个开源框架.  Spring 为简化企业级应用开发而生. 使用 Spring 可以使简单的 JavaBean 实现以前只有 EJB 才能实现的功能.  Spring 是一个 IOC(DI) 和 AOP 容器框架. Spring 是什么(2) 具体描述 Spring:  轻量级:Spring 是非侵入性…

spring入门--spring入门案例

spring是一个框架,这个框架可以干很多很多的事情。感觉特别吊。但是,对于初学者来说,很难理解spring到底是干什么的。我刚开始的时候也不懂,后来就跟着敲,在后来虽然懂了,但是依然说不明白它到底是干啥的。…

Spring入门示例

开发环境 Spring 4.3.0Myeclipse2015JDK1.8 准备阶段: 1、新建一Spring01项目,然后新建一个lib文件。将下面的添加到lib文件中 2、将lib文件所有的包导入项目 开发步骤: 1、新建一个Hello.java的类 1 package com.proc.bean;2 3 public class…

Spring入门详解

准备开始看看Spring源码, 所以把Spring复习一遍,做下笔记.. 一. Spring的4种关键策略: 1. 基于POJO的轻量级和最小入侵编程(基本不出现在业务逻辑中, 不用强制实现接口之类的); 2. 通过依赖注入和面向接口来实现松耦…

1.Spring入门

一、Spring入门 1.Spring 入门 Spring 整体框架: 什么是IOC? IOC:Inversion of Control(控制反转),即将对象的创建权反转给(交给)spring。 spring包含的文件: docs&am…

spring入门例子

Hello World入门 2.5.1、准备开发环境和运行环境: ☆开发工具:eclipse ☆运行环境:tomcat6.0.20 ☆工程:动态web工程(springmvc-chapter2) ☆spring框架下载: spring-framework-3.1.1.RELEASE-w…