002-基于Pytorch的手写汉字数字分类

本节将介绍一种

2.1 准备

2.1.1 数据集

(1)MNIST

只要学习过深度学习相关理论的人,都一定听说过名字叫做LeNet-5模型,它是深度学习三巨头只有Yann Lecun在1998年提出的一个CNN模型(很多人认为这是第一个具有实际应用价值的CNN模型)。在当年使用该模型可以很好地完成手写体数字的识别,而该模型所处理的手写体数字数据库称为MNIST。

MNIST全称是:Mixed National Institute of Standards and Technology databas,它包含70000张手写数字的灰度图片,每一张图片包含 28 X 28 个像素点。数据集被分为两部分,其中训练(mnist.train)集包括60000样本,测试集(mnist.test)包含10000样本。训练集又进一步封你为 55000 个样本用于训练,5000样本用于验证。下图是MNIST样本实例图。

MNIST数据集虽然经典,但也有问题。最主要的问题是,它太简单了!相对于现在动辄上百万个参数的“大”模型,MNIST数据集要小很多,且只是简单的十类问题,因此导致现有的模型在MNIST上的分类精度都超过了95%。为了更直观地观察不同算法间的性能差异,需要用一个更复杂一点的数据集,这时Fashion-MNIST出现了。

(2)Fashion-MNIST

FashionMNIST是一个替代MNIST的图像数据集。 它是由一家德国科技公司(Zalando)整理提供。FashionMNIST 的大小、格式和训练集/测试集划分与原始的 MNIST 完全一致。60000/10000 的训练测试数据划分,28x28 的灰度图片。因此,能跑MNIST数据集的代码,只需稍加改动,就可以跑新的数据集。两个数据集的不同之处主要有两点,一是虽然两者都是以灰度图像呈现的,但MNIST呈现的是数字,背景设为0,前景设为1,FashionMNIST则是真正意义的灰度数据集。二是两者内容不同,前者被分类的是手写体数字,后者则是十类衣物服饰(分别是:T恤、裤子、套头衫、连衣裙、大衣、凉鞋、衬衫、运动鞋、包、短靴),其内容的复杂程度远高于手写体数字。下图是FashionMNIST的一个图示。

网上有很多基于FashionMNIST数据集的实例,在此就不再重复介绍。

本节实例选用的是中国版的MNIST,由英国纽卡斯尔大学整理并提供,我们不妨将其称为CHN-MNIST数据集。

(3)CHN-MNIST

该数据集共由100人书写,每人重复书写10遍,因此数据集样本数为1000组,每组包括15个汉字的数字,即“零、一、二、三、四、五、六、七、八、九、十、百、千、万、亿”,总样本数为15000。图像的分辨率为300*300。

2.1.2  模型

对于这样一个简单的分类任务,不需要使用太复杂的网络,前面提到的LeNet-5足能胜任。

对于LeNet-5网络模型的介绍,网上一搜一大把,在此不再赘述,只贴出该模型的示意图,供大家参考。

2.2 代码解析

下面将结合代码,一部分一部分的介绍具体的过程。

(1)载入必要的扩展库

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

由于是第一个例程,我们对所使用的扩展库详细加以介绍:

  • matplotlib库:用于绘图
  • numpy库:用于数值计算
  • pandas库:用于数据分析
  • torch库:提供Pytorch支持
  • PIL库:用于图像绘制
  • tqdm库:Python提供的进度条空间库

(2)设置参数

这一部分完成的是设置一些与模型训练有关的超参数。如下面代码所示:

batch_size = 32  # 批次大小
lr = 0.003  # 学习率
epochs = 10  # 迭代轮数
save_path = './best_model.pkl'  # 模型保存路径
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 设备

各个参数的功能见注释,至于各个参数数值大小对最终结果的影响,将放在后续的章节介绍。其中最后一行是自动检测是否安装了cuda,如果是,则启动gpu加速。

(3)加载数据集

这一部分完成的是设置一些与模型训练有关的超参数。如下面代码所示:

class CustomDataset(Dataset):def __init__(self, k, l, csv_file='./chinese_mnist.csv'):self.df = pd.read_csv(csv_file)self.k = {'九': int(9), '十': int(10), '百': int(11), '千': int(12), '万': int(13), '亿': int(14), '零': int(0),'一': int(1), '二': int(2), '三': int(3), '四': int(4), '五': int(5), '六': int(6), '七': int(7),'八': int(8)}self.target = 'character'self.features = ['suite_id', 'sample_id', 'code', ]self.labels = np.asarray(self.df.iloc[:, 4])self.y = df[self.target]self.X = df.drop(self.target, axis=1)def __getitem__(self, idx):single_image_label = self.labels[idx]class_id = self.k[single_image_label]img = Image.open(f"./data/data/input_{self.X.iloc[idx, 0]}_{self.X.iloc[idx, 1]}_{self.X.iloc[idx, 2]}.jpg")img = np.array(img)return img, class_iddef __len__(self):return len(self.X)

还需要对数据集进行一下预处理,便于后面的训练过程g

# 1.构建索引到汉字的映射字典
num2char = {int(9): '九', int(10): '十', int(11): '百',int(12): '千', int(13): '万', int(14): '亿',int(0): '零', int(1): '一', int(2): '二',int(3): '三', int(4): '四', int(5): '五',int(6): '六', int(7): '七', int(8): '八'}# 2.读取csv处理文件
df = pd.read_csv('./chinese_mnist.csv', sep=',')# 3.处理数据
train_df = df.groupby('value').apply(lambda x: x.sample(700, random_state=42)).reset_index(drop=True)
x_train, y_train = train_df.iloc[:, :-2], train_df.iloc[:, -2]test_df = df.groupby('value').apply(lambda x: x.sample(300, random_state=42)).reset_index(drop=True)
x_test, y_test = test_df.iloc[:, :-2], test_df.iloc[:, -2]

(未完,待续)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2905366.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

ARMv9新特性:虚拟内存系统架构 (VMSA) 的增强功能

快速链接: 【精选】ARMv8/ARMv9架构入门到精通-[目录] 👈👈👈 权限索引 2022 ARM引入了一种新的控制内存权限方法。 不再是直接在转换表条目 (TTE) 中编码权限,而是使用 TTE 中的字段来索引寄存器中指定的权限数组。这种间接提供…

hadoop-3.1.1分布式搭建与常用命令

一、准备工作 1.首先需要三台虚拟机: master 、 node1 、 node2 2.时间同步 ntpdate ntp.aliyun.com 3.调整时区 cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime 4.jdk1.8 java -version 5.修改主机名 三台分别执行 vim /etc/hostname 并将内容指定为…

大模型时代下的“金融业生物识别安全挑战”机遇

作者:中关村科金AI安全攻防实验室 冯月 金融行业正在面临着前所未有的安全挑战,人脸安全事件频发,国家高度重视并提出警告,全行业每年黑产欺诈涉及资金额超过1100亿元。冰山上是安全事件,冰山下隐藏的是“裸奔”的技术…

主流好用的 Markdown 编辑器介绍

在当今程序员的日常工作中,Markdown 已经成为了一种常用的文本标记语言,它简洁、易读、易写,被广泛应用于写作、文档编写、博客撰写等场景。为了更高效地编辑和管理 Markdown 格式的文档,选择一款功能强大、易用的 Markdown 编辑器…

I.MX6ULL_Linux_驱动篇(55)linux 网络驱动

网络驱动是 linux 里面驱动三巨头之一, linux 下的网络功能非常强大,嵌入式 linux 中也常常用到网络功能。前面我们已经讲过了字符设备驱动和块设备驱动,本章我们就来学习一下linux 里面的网络设备驱动。 嵌入式网络简介 网络硬件接口 首先…

联想 lenovoTab 拯救者平板 Y700 二代_TB320FC原厂ZUI_15.0.677 firmware 线刷包9008固件ROM root方法

联想 lenovoTab 拯救者平板 Y700 二代_TB320FC原厂ZUI_15.0.677 firmware 线刷包9008固件ROM root方法 ro.vendor.config.lgsi.market_name拯救者平板 Y700 ro.vendor.config.lgsi.en.market_nameLegion Tab Y700 #ro.vendor.config.lgsi.short_market_name联想平板 ZUI T # B…

用Kimichat快速识别出图片中的表格保存到Excel

如果有一张图片格式的表格,想要快速复制到Excel表格中,那么一般要借助于OCR工具。之前试过不少在线OCR工具,识别效果差强人意。其实,kimichat就可以非常好的完成这个任务。 下面是一张研报中的表格,只能以图片形式保存…

Linux 系统 CentOS7 上搭建 Hadoop HDFS集群详细步骤

集群搭建 整体思路:先在一个节点上安装、配置,然后再克隆出多个节点,修改 IP ,免密,主机名等 提前规划: 需要三个节点,主机名分别命名:node1、node2、node3 在下面对 node1 配置时,先假设 node2 和 node3 是存在的 **注意:**整个搭建过程,除了1和2 步,其他操作都使…

网络基础(day1)

计算机网络 计算机网络:实现计算机数据的传输。 数据通过网络转发,由应用程序产生!!! 工作组网络电脑直接拿线连接。局域网通过集线器和交换机这样的设备串联起来的网络。城域网城市网络包括了多个小区,多个…

Tensorflow2.0笔记 - metrics做损失和准确度信息度量

本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下: Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博…

Android笔记(三十):PorterDuffXfermode实现旋转进度View

背景 核心原理是使用PorterDuffXfermode Path来绘制进度,并实现圆角 效果图 Android笔记(三十)效果演示 进度条绘制步骤 将ImageView矩形七个点的坐标存储起来(configNodes) 他们对应着7个不同的刻度,每个刻度的值 i * &#…

Stable Diffusion WebUI 生成参数:脚本(Script)——提示词矩阵、从文本框或文件载入提示词、X/Y/Z图表

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 大家好,我是水滴~~ 在本篇文章中,我们将深入探讨 Stable Diffusion WebUI 的另一个引人注目的生成参数——脚本(Script)。我们将逐一细说提示词矩阵、从文本框或文件导入提示词,…

2.4 比较检验 机器学习

目录 常见比较检验方法 总述 2.4.1 假设检验 2.4.2 交叉验证T检验 2.4.3 McNemar 检验 接我们的上一篇《性能度量》,那么我们在某种度量下取得评估结果后,是否可以直接比较以评判优劣呢?实际上是不可以的。因为我们第一,测试…

iOS UIFont-实现三方字体的下载和使用

UIFont 系列传送门 第一弹加载本地字体:iOS UIFont-新增第三方字体 第二弹加载线上字体:iOS UIFont-实现三方字体的下载和使用 前言 在上一章我们完成啦如何加载使用本地的字体。如果我们有很多的字体可供用户选择,我们当然可以全部使用本地字体加载方式,可是这样就增加了…

【Golang入门教程】Go语言变量的初始化

文章目录 强烈推荐引言举例多个变量同时赋值总结强烈推荐专栏集锦写在最后 强烈推荐 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站:人工智能 推荐一个个人工作,日常中比较常…

使用 Yoda 和 ClickHouse 进行实时欺诈检测

背景 Instacart 是北美领先的在线杂货公司,拥有数百万活跃的客户和购物者。在其平台上打击欺诈和滥用行为不仅对于维护一个值得信赖和安全的环境至关重要,也对保持Instacart的财务健康至关重要。在这篇文章中,将介绍了一个欺诈平台——Yoda,解释了为什么我们选择ClickHous…

MyEclipse打开文件跳转到notepad打开问题

问题描述 windows系统打开README.md文件,每次都需要右键选择notepad打开,感觉很麻烦,然后就把README.md文件打开方式默认选择了notepad,这样每次双击就能打开,感觉很方便。 然后某天使用MyEclipse时,双击RE…

基于SpringBoot+VUE的后台资金管理系统

采用技术 基于SpringBootVUE的后台资金管理系统的设计与实现~ 开发语言:Java 数据库:MySQL 技术:SpringBootMyBatis 工具:IDEA/Ecilpse、Navicat、Maven 页面展示效果 员工首页 采购申请 商品添加 数据查询 管理员首页 …

数字化驱动乡村发展:数字乡村助力农村繁荣

随着信息技术的迅猛发展,数字化已成为驱动社会进步的重要引擎。在乡村发展的道路上,数字乡村以其独特的魅力,正在成为推动农村繁荣的重要力量。数字化技术的应用不仅为乡村带来了便捷和高效,更为乡村的经济、社会、文化等多个方面…

mysql 常见运算符

学习了mysql数据类型,接下来学习mysql常见运算符。 2,常见运算符介绍 运算符连接表达式中各个操作数,其作用是用来指明对操作数所进行的运算。运用运算符 可以更加灵活地使用表中的数据,常见的运算符类型有:算…