【深度学习入门篇 ⑧】关于卷积神经网络

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


关于卷积神经网络,你还有哪些不知道的知识点呢,之前我们介绍了大部分,今天再来补充一下~

卷积神经网络基础

什么是卷积

Convolution,输入信息与核函数(滤波器)的乘积

  • 一维信号的时间卷积:输入x,核函数w,输出是一个连续时间段t的加权平均结果。
  • 二维图像的空间卷积:输入图像I,卷积核K,输出图像O。

单个二维图片卷积 :输入为单通道图像,输出为单通道图像。

图像的数据存储是多通道的二维矩阵:

灰度图(Gray)只有一个通道(一层),RGB彩色图就是三个通道(Red,Green,Blue),而RGBA彩色图就是四个通道(Red,Green,Blue,Alpha)。

如何表达每一个网络层中高维的图像数据?

特征图包含:通道,宽度,高度,其中输入特征图Ci,输出特征图C0,输出特征图的每一个通道,由输入图的所有通道和相同数量的卷积核先一一对应各自进行卷积计算,然后求和

 

卷积相关操作与参数

填充

padding :给卷积前的输入图像边界添加额外的行,列

  • 控 制 卷 积 后 图 像分 辨 率 , 方便计算特征图尺寸的变化
  • 弥 补 边 界 信 息 “丢 失 ” 

步长

步长(stride):卷积核在图像上移动的步子

卷积的核心思想 

为什么要进行局部连接?

  • 局部连接可以更好地利用图像中的结构信息,空间距离越相近的像素其相互影响越大

权重共享:保证不变性,图像从一个局部区域学习到的信息应用到其他区域 ,减少参数,降低学习难度。

ANN与CNN比较

传统神经网络为有监督的机器学习,输入为特征;卷积神经网络为无监督特征学习,输入为最原始的图像。


案例-图像分类

CIFAR10 数据集

CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32×32×3。

 PyTorch 中的 torchvision.datasets 计算机视觉模块封装了 CIFAR10 数据集:

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoaderdef func1():# 加载数据集train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))print('训练集数量:', len(train.targets))print('测试集数量:', len(valid.targets))print("数据集形状:", train[0][0].shape)print("数据集类别:", train.class_to_idx)# 数据加载器
def func2():train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))dataloader = DataLoader(train, batch_size=8, shuffle=True)for x, y in dataloader:print(x.shape)print(y)breakif __name__ == '__main__':func1()func2()

我们要搭建的网络结构:

我们在每个卷积计算之后应用 relu 激活函数来给网络增加非线性因素。

网络代码实现:

class ImageClassification(nn.Module):def __init__(self):super(ImageClassification, self).__init__()self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.linear1 = nn.Linear(576, 120)self.linear2 = nn.Linear(120, 84)self.out = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)x = x.reshape(x.size(0), -1)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))return self.out(x)

 编写训练函数

训练时,使用多分类交叉熵损失函数,Adam 优化器:

def train():transgform = Compose([ToTensor()])cifar10 = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transgform)model = ImageClassification()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)# 训练轮数epoch = 100for epoch_idx in range(epoch):# 构建数据加载器dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)# 样本数量sam_num = 0# 损失总和total_loss = 0.0# 开始时间start = time.time()correct = 0for x, y in dataloader:# 送入模型output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()correct += (torch.argmax(output, dim=-1) == y).sum()total_loss += (loss.item() * len(y))sam_num += len(y)print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %(epoch_idx + 1,total_loss / sam_num,correct / sam_num,time.time() - start))torch.save(model.state_dict(), 'model/image_classification.bin')

编写预测函数

我们加载训练好的模型,对测试集中的 1 万条样本进行预测,查看模型在测试集上的准确率

def test():transgform = Compose([ToTensor()])cifar10 = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transgform)# 构建数据加载器dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)# 加载模型model = ImageClassification()model.load_state_dict(torch.load('model/image_classification.bin'))model.eval()total_correct = 0total_samples = 0for x, y in  dataloader:output = model(x)total_correct += (torch.argmax(output, dim=-1) == y).sum()total_samples += len(y)print('Acc: %.2f' % (total_correct / total_samples))

输出:

'Acc: 0.61

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3246720.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

昇思25天学习打卡营第13天|CycleGAN 图像风格迁移互换全流程解析

目录 数据集下载和加载 可视化 构建生成器 构建判别器 优化器和损失函数 前向计算 计算梯度和反向传播 模型训练 模型推理 数据集下载和加载 使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install dow…

el-table表格操作列错行处理

解决方法&#xff1a; <style>::v-deep .el-table th.el-table__cell > .cell {white-space: nowrap !important;} </style>

Template execution failed: ReferenceError: name is not defined

问题 我们使用了html-webpack-plugin&#xff08;webpack&#xff09;进行编译html&#xff0c;导致的错误。 排查结果 连接地址 html-webpack-plugin版本低(2.30.1)&#xff0c;html模板里面不能有符号&#xff0c;注释都不行 // var reg new RegExp((^|&)${name}([^&…

测试——基础篇

内容纲要: 软件测试的生命周期 如何描述一个bug 如何定义bug的级别 bug的生命周期 如何开始第一次测试 测试的执行和bug管理 产生争执怎么办 1. 软件测试的生命周期 需求分析→测试计划→ 测试设计、测试开发→ 测试执行→ 测试评估 需求分析:需求是否正确,是否可行测试计划:…

Qt基础 | Qt全局定义 | qglobal头文件中的数据类型、函数、宏定义

文章目录 一、数据类型定义二、函数三、宏定义 QtGlobal头文件包含了 Qt 类库的一些全局定义 &#xff0c;包括基本数据类型、函数和宏&#xff0c;一般的Qt类的头文件都会包含该文件。 详细内容可参考&#xff1a;https://doc.qt.io/qt-5/qtglobal.html 一、数据类型定义 为了…

《Python机器学习项目实战》书籍介绍

文章目录 书籍介绍主要内容书籍目录 书籍介绍 《Python机器学习项目实战》带领大家在构建实际项目的过程中&#xff0c;掌握关键的机器学习概念&#xff01;使用机器学习&#xff0c;我们可完成客户行为分析、价格趋势预测、风险评估等任务。要想掌握机器学习&#xff0c;需要…

2024年大数据高频面试题(中篇)

文章目录 Kafka为什么要用消息队列为什么选择了kafkakafka的组件与作用(架构)kafka为什么要分区Kafka生产者分区策略kafka的数据可靠性怎么保证ack应答机制(可问:造成数据重复和丢失的相关问题)副本数据同步策略ISRkafka的副本机制kafka的消费分区分配策略Range分区分配策略…

网络准入控制设备是什么?有哪些?网络准入设备臻品优选

小李&#xff1a;“小张&#xff0c;最近公司网络频繁遭遇外部攻击&#xff0c;我们得加强一下网络安全了。” 小张&#xff1a;“是啊&#xff0c;我听说实施网络准入控制是个不错的选择。但具体什么是网络准入控制设备&#xff1f;我们有哪些选择呢&#xff1f;” 小李微笑…

数据结构历年考研真题对应知识点(哈夫曼树和哈夫曼编码)

目录 5.5.1哈夫曼树和哈夫曼编码 1.哈夫曼树的定义 2.哈夫曼树的构造 【分析哈夫曼树的路径上权值序列的合法性&#xff08;2010&#xff09;】 【哈夫曼树的性质&#xff08;2010、2019&#xff09;】 3.哈夫曼编码 【根据哈夫曼编码对编码序列进行译码&#xff08;201…

乘积量化pq:将高维向量压缩 97%

向量相似性搜索在处理大规模数据集时&#xff0c;往往面临着内存消耗的挑战。例如&#xff0c;即使是一个包含100万个密集向量的小数据集&#xff0c;其索引也可能需要数GB的内存。随着数据集规模的增长&#xff0c;尤其是高维数据&#xff0c;内存使用量会迅速增加&#xff0c…

达梦 ./disql SYSDBA/SYSDBA报错[-70028]:创建SOCKET连接失败. 解决方法

原因 达梦命令./disql SYSDBA/SYSDBA默认访问端口5236&#xff0c;如果初始化实例的时候修改了端口&#xff0c;需要指定端口访问 解决 ./disql SYSDBA/SYSDBA192.168.10.123:5237

数据结构(5.2_1)——二叉树的基本定义和术语

二叉树的基本概念 二叉树是n(n>0)个结点的有限集合: 或者为空二叉树&#xff0c;即n0&#xff1b;或者由一个根结点和两个互不相交的被称为根的左子树和右子树组成。左子树和右子树又分别是一颗二叉树。 特点:每个结点至多只有两颗字数&#xff1b;左子树不能颠倒(二叉树…

09 深度推荐模型演化中的“平衡与不平衡“规律

你好&#xff0c;我是大壮。08 讲我们介绍了深度推荐算法中的范式方法&#xff0c;并简单讲解了组合范式推荐方法&#xff0c;其中还提到了多层感知器&#xff08;MLP&#xff09;。因此&#xff0c;这一讲我们就以 MLP 组件为基础&#xff0c;讲解深度学习范式的其他组合推荐方…

pico+unity3d手部动画

在 Unity 开发中&#xff0c;输入系统的选择和运用对于实现丰富的交互体验至关重要。本文将深入探讨 Unity 中的 Input System 和 XR Input Subsystem 这两种不同的输入系统&#xff0c;并详细介绍它们在控制手部动画方面的应用。 一、Input System 和 XR Input Subsystem 的区…

每日练习,不要放弃

目录 题目1.下面叙述错误的是 ( )2.java如何返回request范围内存在的对象&#xff1f;3.以下代码将打印出4.下列类定义中哪些是合法的抽象类的定义&#xff1f;&#xff08;&#xff09;5.以下代码段执行后的输出结果为6.以下代码运行输出的是总结 题目 选自牛客网 1.下面叙述…

【node-RED 4.0.2】连接操作 Oracle 数据库实现 增 删 改 查【新版,使用新插件:@hylink/node-red-oracle】

总览 上节课&#xff0c;我们说到&#xff0c;在 node-red 上链接 oracle 数据库 我们使用的插件是 node-red-contrib-agur-connector。 其实后来我发现&#xff0c;有一个插件更简便&#xff0c;并且也更好用&#xff1a;hylink/node-red-oracle &#xff01;&#xff01;&am…

Linux--网络基础

计算机网络背景 计算机网络背景是一个复杂而丰富的领域&#xff0c;涵盖了从计算机单机模式到网络互联的演变过程&#xff0c;以及网络技术的不断发展和创新。 计算机单机模式和独立发展 在早期&#xff0c;计算机主要以单机模式存在&#xff0c;即每台计算机都是独立的&…

传知代码-揭秘AI如何揪出图片中的“李鬼”(论文复现)

代码以及视频讲解 本文所涉及所有资源均在传知代码平台可获取 文字篡改图像的“照妖镜”&#xff1a;揭秘AI如何揪出图片中的“李鬼” 在数字化时代&#xff0c;我们时常被各种图像信息所包围。然而&#xff0c;这些图像中有时隐藏着不为人知的秘密——被篡改的文字或图像。这…

C++ | Leetcode C++题解之第238题除自身以外数组的乘积

题目&#xff1a; 题解&#xff1a; class Solution { public:vector<int> productExceptSelf(vector<int>& nums) {int length nums.size();// L 和 R 分别表示左右两侧的乘积列表vector<int> L(length, 0), R(length, 0);vector<int> answer(l…