【深度学习笔记】3_5 图像分类数据集fashion-mnist

注:本文为《动手学深度学习》开源内容,仅为个人学习记录,无抄袭搬运意图

3.5 图像分类数据集(Fashion-MNIST)

在介绍softmax回归的实现前我们先引入一个多类图像分类数据集。它将在后面的章节中被多次使用,以方便我们观察比较算法之间在模型精度和计算效率上的区别。图像分类数据集中最常用的是手写数字识别数据集MNIST[1]。但大部分模型在MNIST上的分类精度都超过了95%。为了更直观地观察算法之间的差异,我们将使用一个图像内容更加复杂的数据集Fashion-MNIST[2](这个数据集也比较小,只有几十M,没有GPU的电脑也能吃得消)。

本节我们将使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils: 其他的一些有用的方法。

3.5.1 获取数据集

首先导入本节需要的包或模块。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

下面,我们通过torchvision的torchvision.datasets来下载这个数据集。第一次调用时会自动从网上获取数据。我们通过参数train来指定获取训练数据集或测试数据集(testing data set)。测试数据集也叫测试集(testing set),只用来评价模型的表现,并不用来训练模型。

另外我们还指定了参数transform = transforms.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor

注意: 由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括transforms.ToTensor()在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成uint8,避免不必要的bug。 详见传送门2.2.4节。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

上面的mnist_trainmnist_test都是torch.utils.data.Dataset的子类,所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本。训练集中和测试集中的每个类别的图像数分别为6,000和1,000。因为有10个类别,所以训练集和测试集的样本数分别为60,000和10,000。

print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

输出:

<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000

我们可以通过下标来访问任意一个样本:

feature, label = mnist_train[0]
print(feature.shape, label)  # Channel x Height x Width

输出:

torch.Size([1, 28, 28]) tensor(9)

变量feature对应高和宽均为28像素的图像。由于我们使用了transforms.ToTensor(),所以每个像素的数值为[0.0, 1.0]的32位浮点数。需要注意的是,feature的尺寸是 (C x H x W) 的,而不是 (H x W x C)。第一维是通道数,因为数据集中是灰度图像,所以通道数为1。后面两维分别是图像的高和宽。

Fashion-MNIST中一共包括了10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。

# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

下面定义一个可以在一行里画出多张图像和对应标签的函数。

# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):d2l.use_svg_display()# 这里的_表示我们忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()

现在,我们看一下训练数据集中前10个样本的图像内容和文本标签。

X, y = [], []
for i in range(10):X.append(mnist_train[i][0])y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

在这里插入图片描述

3.5.2 读取小批量

我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,mnist_traintorch.utils.data.Dataset的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例。

在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置4个进程读取数据。

batch_size = 256
if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

我们将获取并读取Fashion-MNIST数据集的逻辑封装在d2lzh_pytorch.load_data_fashion_mnist函数中供后面章节调用。该函数将返回train_itertest_iter两个变量。随着本书内容的不断深入,我们会进一步改进该函数。它的完整实现将在5.6节中描述。

最后我们查看读取一遍训练数据需要的时间。

start = time.time()
for X, y in train_iter:continue
print('%.2f sec' % (time.time() - start))

输出:

1.57 sec

小结

  • Fashion-MNIST是一个10类服饰分类数据集,之后章节里将使用它来检验不同算法的表现。
  • 我们将高和宽分别为 h h h w w w像素的图像的形状记为 h × w h \times w h×w(h,w)

参考文献

[1] LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/

[2] Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.


注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2806438.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

论文是怎么一回事

最近找到女朋友了&#xff0c;她还挺关心我毕业和论文的事情&#xff0c;我开始着手弄论文了~ 说来惭愧&#xff0c;我一直以为读研就是做东西当作工作来完成&#xff0c;结果一直陷入如何实现的问题&#xff0c;结果要论文时不知道怎么弄创新点&#xff0c;这才转过头来弄论文…

信号系统之线性图像处理

1 卷积 图像卷积的工作原理与一维卷积相同。例如&#xff0c;图像可以被视为脉冲的总和&#xff0c;即缩放和移位的delta函数。同样&#xff0c;线性系统的特征在于它们如何响应脉冲。也就是说&#xff0c;通过它们的脉冲响应。系统的输出图像等于输入图像与系统脉冲响应的卷积…

重点媒体如何投稿?考核稿件投稿指南

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 机构组织&#xff0c;国企央企都需要定期将相关新闻投递到央媒&#xff0c;官媒&#xff0c;或者地方重点媒体中&#xff0c;那么如何进行投稿了&#xff0c;今天就与大家分享下。 央媒投…

喝多少瓶汽水

喝多少瓶汽水 题目描述&#xff1a;解法思路&#xff1a;解法代码&#xff1a;运行结果: 题目描述&#xff1a; 水已知1瓶汽水1元&#xff0c;2个空瓶可以换⼀瓶汽水&#xff0c;输入整数n&#xff08;n>0&#xff09;&#xff0c;表示n元钱&#xff0c;计算可以多少汽水&a…

《Docker 简易速速上手小册》第2章 容器和镜像(2024 最新版)

文章目录 2.1 理解 Docker 容器2.1.1 重点基础知识2.1.2 重点案例&#xff1a;使用 Docker 运行 Python 应用2.1.3 拓展案例 1&#xff1a;Docker 中的 Flask 应用2.1.4 拓展案例 2&#xff1a;Docker 容器中的数据分析 2.2 创建与管理 Docker 镜像2.2.1 重点基础知识2.2.2 重点…

力扣 169. 多数元素

思路&#xff1a; 因为题目说一定存在多数元素&#xff0c;就说明一定有一个数的个数多于n/2 将数组采用冒泡从小到大排序&#xff0c;最中间的那个元素一定是多数元素&#xff08;因为在大小排好序后&#xff0c;中位数也一定是众数&#xff09; 答案&#xff1a; int maj…

Xcode中App图标和APP名称的修改

修改图标 选择Assets文件 ——> 点击Applcon 换App图标 修改名称 点击项目名 ——> General ——> Display Name

C++中的STL数据结构

内容来自&#xff1a;代码随想录&#xff1a;哈希表理论基础 1.常见的三种哈希结构 当我们想使用哈希法来解决问题的时候&#xff0c;我们一般会选择如下三种数据结构 数组 set &#xff08;集合&#xff09; map(映射) 在C中&#xff0c;set 和 map 分别提供以下三种数据结构…

C++ 学习之Map容器

C++ Map容器概念 C++的Map容器是一种关联容器,它提供了一种将键和值相关联的方式。它以键值对的形式存储数据,并根据键的顺序自动进行排序。 Map中的键是唯一的,而值可以重复。你可以使用键来访问对应的值,就像使用索引访问数组中的元素一样。 Map容器的特点如下: 按照…

leetcode单调栈

739. 每日温度 请根据每日 气温 列表&#xff0c;重新生成一个列表。对应位置的输出为&#xff1a;要想观测到更高的气温&#xff0c;至少需要等待的天数。如果气温在这之后都不会升高&#xff0c;请在该位置用 0 来代替。 例如&#xff0c;给定一个列表 temperatures [73, …

29. 【Linux教程】Linux 用户介绍

本小节介绍 Linux 用户的基础知识&#xff0c;了解 Linux 系统中有哪些用户&#xff0c;如何查看当前 Linux 系统中有哪些用户&#xff0c;每一个 Linux 用户的权限取决于这些账号登录时获取到的权限。 1. Linux 用户类型 Linux 系统是一个多用户多任务的操作系统&#xff0c;…

服务内存优化思路

【需求背景】 由于某个服务在访问量不大的情况下&#xff0c;出现频繁的 full gc&#xff0c;配置内存为 1.5g&#xff0c;但是还是不够用&#xff0c;经常占用在 1.3g 左右。 【内存分析】 1、使用 MAT 对 dump 日志进行分析 通过 dominator-tree 可以看出 java.lang.Thread…

145.二叉树的后序遍历

// 定义一个名为Solution的类&#xff0c;用于解决二叉树的后序遍历问题 class Solution { // 定义一个公共方法&#xff0c;输入是一个二叉树的根节点&#xff0c;返回一个包含后序遍历结果的整数列表 public List<Integer> postorderTraversal(TreeNode root) { /…

【开源】JAVA+Vue.js实现考研专业课程管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 考研高校模块2.3 高校教师管理模块2.4 考研专业模块2.5 考研政策模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 考研高校表3.2.2 高校教师表3.2.3 考研专业表3.2.4 考研政策表 四、系统展示五、核…

高级统计方法 第4次作业

作业评阅&#xff1a; 概念 2.问题 KNN分类和KNN回归都是KNN算法在不同类型数据上的应用&#xff0c;但它们之间存在明显的区别。 解决的问题类型不同&#xff1a;KNN分类适用于解决分类问题&#xff0c;而KNN回归则适用于解决回归问题。当响应变量是连续的&#xff0c;根据…

我花了5天时间,开发了一个在线学习的小网站

大三寒假赋闲在家&#xff0c;闲来无事&#xff0c;用了5天时间做了一个在线学习的小网站&#xff0c;一鼓作气部署上线&#xff0c;制作的过程比较坎坷。内心经历过奔溃&#xff0c;也经历过狂喜。 按照惯例先放出网址&#xff0c;欢迎大家来访问学习&#xff1a;www.pbjlove…

大离谱!AI写作竟让孔子遗体现身巴厘岛,看完笑不活了

大家好&#xff0c;我是二狗。 这两天我在知乎上看到了一个AI写作大翻车的案例&#xff0c;看完简直笑不活了&#xff0c;特地分享给大家一起 happy happy&#xff5e; 知乎网友“打开盒子吓一跳”一上来就抛出来了一个“孔子去世”的王炸。 首先&#xff0c;下面是一条真实新…

每日一题——LeetCode1512.好数对的数目

方法一 暴力循环 var numIdenticalPairs function(nums) {let ans 0;for (let i 0; i < nums.length; i) {for (let j i 1; j < nums.length; j) {if (nums[i] nums[j]) {ans;}}}return ans; }; 消耗时间和内存情况&#xff1a; 方法二&#xff1a;组合计数 var …

智胜未来,新时代IT技术人风口攻略-第七版(弃稿)

文章目录 前言鸿蒙生态科普调研人员画像角色先行结论 - 市场下的增量蛋糕高校助力鸿蒙 - 掀起鸿蒙教育热潮高校鸿蒙课程开设占比 - 巨大需求背后是矛盾冲突教研力量并非唯一原因 - 看重教学成果复用与效率 企业布局规划 - 多元市场前瞻视野全盘接纳仍需一段时间 - 积极正向的一…

植物神经功能紊乱不治疗最坏后果会怎样?

植物神经功能紊乱是一种常见的疾病&#xff0c;它可以对人体的生理和心理产生严重的影响。如果不加以治疗&#xff0c;其最坏的后果将会是非常危险的。 植物神经功能紊乱是由于各种原因导致自主神经系统异常活跃或抑制而引起的一系列症状的总称。自主神经系统是负责自主调…