深度学习(15)--PyTorch构建卷积神经网络

目录

一.PyTorch构建卷积神经网络(CNN)详细流程

二.graphviz + torchviz使PyTorch网络可视化

2.1.可视化经典网络vgg16

2.2.可视化自己定义的网络


一.PyTorch构建卷积神经网络(CNN)详细流程

卷积神经网络(Convolutional Neural Networks)是一种深度学习模型或类似于人工神经网络的多层感知器,常用来分析视觉图像。

卷积神经网络的详细介绍可以参考博主写的文章:

深度学习(2)--卷积神经网络(CNN)-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/GodFishhh/article/details/135668789?spm=1001.2014.3001.5501

PyTorch构建神经网络的第一步均为引入神经网络包

import torch.nn as nn

卷积神经网络的构建: 

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层->激活函数->池化层self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)  pytorch中是channel_first的,颜色通道写在第一个位置nn.Conv2d(                      # 1d对结构化数据 2d对图像数据 3d对视频数据in_channels=1,              # 灰度图   输入的特征图数out_channels=16,            # 要得到几多少个特征图  输出的特征图数,也就是卷积核的个数(一个卷积核进行卷积可以得到一个特征图,所以卷积核的个数与特征图的数量相同)kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1  卷积后的图像大小: (h - Kernel_size + 2*p) / s + 1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14)  池化后特征数变少)self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)  nn.ReLU(),                      # relu层nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果  最终数据的大小以及分类的数量def forward(self, x):# 调用卷积层x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7),分类无法对三维的数据进行处理,所以需要将三维图像拉长成一维数据再来进行分类.# -1是自动计算,只需给出一个维度的大小,会自动计算另外个维度.eg.5x4 -> x.view(2,-1),-1对应的就是10. 2x5x10 -> x.view(2,-1),-1对应的就是5x10# 在此处,给出的第一个参数x.size(0)的值为batch,所以-1对应的值就是32x7x7# 调用全连接层(全连接层的输入必须是二维的矩阵,上述的flattern操作将参数x变成了一个二维矩阵)output = self.out(x)return output

详解:

1.创建的神经网络构建类一定要继承nn.Module,后续要调用Module包里面的方法构建神经网络。

2.构造函数的第一步永远是调用父类的构造函数,利用super()进行调用:

super(CNN, self).__init__()

3.卷积神经网络的层次顺序一般为:卷积层-> 激活函数做非线性变换 ->池化层,并在输出之前设置一层全连接层。

4.上述代码构建的卷积神经网络是顺序Sequential的,设置有两个卷积层,两个激活函数,两个池化层,以及输出前的一个全连接层。(一般卷积一次就要池化一次)

nn.Sequential()

5.卷积层的构造:通过Module模块中的Conv2d来构造卷积层,其中参数分别为:输入图片数据的颜色通道数(第一个卷积层)/输入的特征图数(之后的卷积层)、输出的特征图数、卷积核的大小、步长、padding值。(其中Conv1d用来处理结构化数据,Conv2d用来处理图片数据,Conv3d用来处理视频数据)

此处设置的卷积层由输入的1个特征图数得到最后的32个特征图数

nn.Conv2d(1, 32, 5, 1, 2)
nn.Conv2d(16, 32, 5, 1, 2)

值得注意的是,如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1  卷积后的图像大小: (h - Kernel_size + 2*p) / s + 1。

6.此处激活函数设置的是ReLU,可以根据自己的需求设置不同的激活函数。

nn.ReLU()

7.池化层的构造: 只需要设置一个参数,即为进行池化操作的区域大小。

nn.MaxPool2d(kernel_size=2)

8.全连接层的构造:输入的数据最后经过全连接层得到输出数据,参数分别为输入数据的大小,以及最后进行分类的类别数。

self.out = nn.Linear(32 * 7 * 7, 10)

9.前向传播:PyTorch构建的神经网络,前向传播需要手动设置,此处先调用conv1和conv2两层,再将数据拉成二维的传入全连接层,得到最后的输出值。

二.graphviz + torchviz使PyTorch网络可视化

事先需要先安装graphviz库和torchviz库,graphviz具体安装步骤可以参考博主写的文章:

深度学习(9)--pydot库和graphviz库安装流程详解_pydot 怎么安装-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/GodFishhh/article/details/135929146?spm=1001.2014.3001.5501torchviz库可以直接再编译器中进行安装,也可也在cmd中对应环境中使用pip指令安装:

上述两个库安装完之后,导入网络可视化需要用到的头文件:

from torchviz import make_dot
from torchvision.models import vgg16  # 导入vgg16模型用于演示

2.1.可视化经典网络vgg16

# 随机生成一个tensor张量(对应的数据为图片有十张,图片的大小为3x32x32)
x = torch.randn(10, 3, 32, 32)
# 实例化 vgg16
model = vgg16()
# 将 x 输入网络
vgg16_out = model(x)
# 实例化 make_dot
vgg16_result = make_dot(vgg16_out)
# result.view()  直接在当前路径下保存 pdf 并打开
# 保存文件为pdf到指定路径并不打开
vgg16_result.render(filename='vgg16_net_Structure', view=False, format='pdf')

生成如下两个文件 

 

2.2.可视化自己定义的网络

# 随机生成一个tensor张量(对应的数据为图片有四张,图片的大小为1x28x28)
x = torch.randn(4, 1, 28, 28)
# 实例化 vgg16
model = CNN()
# 将 x 输入网络
CNN_out = model(x)
# 实例化 make_dot
CNN_result = make_dot(CNN_out)
# result.view()  直接在当前路径下保存 pdf 并打开
# 保存文件为pdf到指定路径并不打开
CNN_result.render(filename='CNN_net_Structure', view=False, format='pdf')

生成如下两个文件  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2774376.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot源码解读与原理分析(七)BeanFactory

文章目录 3 SpringBoot的IOC容器3.1 SpringFramework的IOC容器3.1.1 BeanFactory3.1.1.1 BeanFactory根接口3.1.1.2 HierarchicalBeanFactory3.1.1.3 ListableBeanFactory3.1.1.4 AutowireCapableBeanFactory3.1.1.5 ConfigurableBeanFactory3.1.1.6 AbstractBeanFactory3.1.1.…

机器学习之指数分布

指数分布: 指数分布可以用来表示独立随机事件发生的时间间隔。如果一个随机变量X的概率密度函数满足以下形式,就称X服从参数λ的指数分布,记作X ~ E(λ)或X~Exp(λ)。指数分布只有一个指数参数,且λ>0&a…

SolidWorks学习笔记——入门知识2

目录 建出第一个模型 1、建立草图 2、选取中心线 3、草图绘制 4、拉伸 特征的显示与隐藏 改变特征名称 5、外观 6、渲染 建出第一个模型 1、建立草图 图1 建立草图 按需要选择基准面。 2、选取中心线 图2 选取中心线 3、草图绘制 以对称图形举例,先画出…

【GAMES101】Lecture 18 高级光线传播

这节课不涉及数学原理,只讲流程操作,大家当听这个十万个为什么就行 目录 高级光线传播 无偏光线传播方法 双向路径追踪(Bidirectional path tracing) Metropolis light transport (MLT) 有偏光线传播方法 光子映射(Photon …

FXTM富拓监管变更!2024开年连续3家交易商注销牌照

交易商的监管信息是经常发生变更的,即使第一次投资时查询平台监管牌照,投资者仍需持续关注其监管动态。千万不要以为第一步审核好后就万事大吉了! 2024年开年,就有3家交易商的重要信息发生变更,注销其金融监管牌照&…

按键扫描16Hz-单片机通用模板

按键扫描16Hz-单片机通用模板 一、按键扫描的原理1、直接检测高低电平类型2、矩阵扫描类型3、ADC检测类型二、---.c的实现1、void keyScan(void) 按键扫描函数①void FHiKey(void) 按键按下功能②void FSameKey(void) 按键长按功能③void FLowKey(void) 按键释放功能三、key.h的…

Leetcode—135. 分发糖果【中等】

2024每日刷题(113) Leetcode—135. 分发糖果 算法思想 这里可以利用贪心策略,求局部最优解,然后合并为全局最优解。具体来说,将原问题中相邻孩子的条件划分为左相邻孩子和右相邻孩子两个条件,依次求解出两…

Phobos捆绑某数控软件AdobeIPCBroker组件定向勒索

前言 Phobos勒索病毒最早于2019年被首次发现并开始流行起来,该勒索病毒的勒索提示信息特征与CrySiS(Dharma)勒索病毒非常相似,但是两款勒索病毒的代码特征却是完全不一样,近日笔者在逛某开源恶意软件沙箱的时候发现了一款Phobos勒索病毒捆绑…

应用层DoS

应用层(application layer)是七层OSI模型的第七层。应用层直接和应用程序 对接并提供常见的网络应用服务,能够在实现多个系统应用进程相互通信的同 时,完成一系列业务处理所需的服务。位于应用层的协议有很多,常见的包…

【已解决】:pip is configured with locations that require TLS/SSL

在使用pip进行软件包安装的时候出现问题: WARNING: pip is configured with locations that require TLS/SSL, however the ssl module in Python is not available. 解决: mkdir -p ~/.pip vim ~/.pip/pip.conf然后输入内容: [global] ind…

07-OpenFeign-HTTP压缩优化

gzip是一种数据格式,采用用deflate算法压缩数据;gzip是一种流行的数据压缩算法,应用十分广泛,尤其是在Linux平台。 当GZIP压缩到一个纯文本数据时,效果是非常明显的,大约可以减少70%以上的数据…

第九个知识点:内部对象

Date对象: <script>var date new Date();date.getFullYear();//年date.getMonth();//月date.getDate();//日date.getDay();//星期几date.getHours();//时date.getMinutes();//分date.getSeconds();//秒date.getTime();//获取时间戳&#xff0c;时间戳时全球统一&#x…

C++力扣题目494--目标和 474--一和零

494.目标和 力扣题目链接(opens new window) 难度&#xff1a;中等 给定一个非负整数数组&#xff0c;a1, a2, ..., an, 和一个目标数&#xff0c;S。现在你有两个符号 和 -。对于数组中的任意一个整数&#xff0c;你都可以从 或 -中选择一个符号添加在前面。 返回可以使…

Backtrader 文档学习- Plotting

Backtrader 文档学习- Plotting 虽然回测是一个基于数学计算的自动化过程&#xff0c;还是希望实际通过可视化验证。无论是使用现有算法回测&#xff0c;还是观察数据驱动的指标&#xff08;内置或自定义&#xff09;。 凡事都要有人完成&#xff0c;绘制数据加载、指标、操作…

PostgreSql与Postgis安装

POstgresql安装 1.登录官网 PostgreSQL: Linux downloads (Red Hat family) 2.选择版本 3.安装 ### 源 yum install -y https://download.postgresql.org/pub/repos/yum/reporpms/EL-7-x86_64/pgdg-redhat-repo-latest.noarch.rpm ### 客户端 yum install postgresql14 ###…

Java面向对象 继承

目录 继承继承的好处继承具有传递性实例创建Person类Student继承Person类测试 继承 Java中的继承是面向对象编程的一个核心特性&#xff0c;它允许一个类&#xff08;子类或派生类&#xff09;继承另一个类&#xff08;父类或基类&#xff09;的属性和方法。通过继承&#xff0…

【HarmonyOS应用开发】HTTP数据请求(十四)

文章末尾含相关内容源代码 一、概述 日常生活中我们使用应用程序看新闻、发送消息等&#xff0c;都需要连接到互联网&#xff0c;从服务端获取数据。例如&#xff0c;新闻应用可以从新闻服务器中获取最新的热点新闻&#xff0c;从而给用户打造更加丰富、更加实用的体验。 那么…

【Spring】GoF 之工厂模式

一、GoF 23 设计模式简介 设计模式&#xff1a;一种可以被重复利用的解决方案 GoF&#xff08;Gang of Four&#xff09;&#xff0c;中文名——四人组 《Design Patterns: Elements of Reusable Object-Oriented Software》&#xff08;即《设计模式》一书&#xff09;&…

JavaEE作业-实验三

目录 1 实验内容 2 实验要求 3 思路 4 核心代码 5 实验结果 1 实验内容 简单的线上图书交易系统的web层 2 实验要求 ①采用SpringMVC框架&#xff0c;采用REST风格 ②要求具有如下功能&#xff1a;商品分类、订单、购物车、库存 ③独立完成&#xff0c;编写实验报告 …

Unity2D 学习笔记 0.Unity需要记住的常用知识

Unity2D 学习笔记 0.Unity需要记住的常用知识 前言调整Project SettingTilemap相关&#xff08;创建地图块&#xff09;C#脚本相关程序运行函数private void Awake()void Start()void Update() Collider2D碰撞检测private void OnTriggerStay2D(Collider2D player)private void…