【深度学习入门篇 ⑨】循环神经网络实战

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


今天我们看一下用循环神经网络RNN的原理并且动手应用到案例。

3e012755cfd647aebdf70ff24536d38b.png 

循环神经网络

在普通的神经网络中,信息的传递是单向的,这种限制虽然使得网络变得更容易学习,但在一定程度上也减弱了神经网络模型的能力。特别是在很多现实任务中,网络的输出不仅和当前时刻的输入相关,也和其过去一段时间的输出相关。此外,普通网络难以处理时序数据,比如视频、语音、文本等,时序数据的长度一般是不固定的,而前馈神经网络要求输入和输出的维数都是固定的,不能任意改变。因此,当处理这一类和时序相关的问题时,就需要一种能力更强的模型。

循环神经网络 (RNN)是一类具有短期记忆能力的神经网络。在循环神经网络中,神经元不但可以接受其它神经元的信息,也可以接受自身的信息,形成具有环路的网络结构。  

ab119b30479c4d74bb10bf02ef0d9f34.png 

RNN比传统的神经网络多了一个循环圈,这个循环表示的就是在下一个时间步上会返回作为输入的一部分,我们把RNN在时间点上展开 :

6e2096802ad346c1836d1ede9370a9fe.png

在不同的时间步,RNN的输入都将与之前的时间状态有关 ,具体来说,每个时间步的RNN单元都会接收两个输入:当前时间步的外部输入和前一时间步(隐藏层)的输出状态。通过这种方式,RNN能够学习并理解数据中的长期依赖关系,使得它在处理文本生成、语音识别、时间序列预测等序列数据时表现尤为出色。

此外,RNN的隐藏状态(或称为内部状态)在每次迭代时都会更新,这种更新过程包含了当前输入和前一时间步状态的非线性组合,使得网络能够动态地调整其对序列中接下来内容的预测或理解。

d1ad2acff14b48458791021e8ce8eaa5.png

LSTM和GRU

传统的RNN在处理长序列数据时常常面临梯度消失或梯度爆炸的问题,这限制了其在处理长期依赖关系上的能力。为了克服这一局限性,LSTM(Long Short-Term Memory,长短期记忆网络)作为RNN的一种变体被引入。

LSTM是一种RNN特殊的类型,可以学习长期依赖信息。在很多问题上,LSTM都取得相当巨大的成功,并得到了广泛的应用。

48465d18371741739f23324e0f1f3e05.png

LSTM是通过一个叫做的结构实现,门可以选择让信息通过或者不通过。 这个门主要是通过sigmoid和点乘实现的 ;sigmoid 的取值范围是在(0,1)之间,如果接近0表示不让任何信息通过,如果接近1表示所有的信息都会通过。

  • 遗忘门通过sigmoid函数来决定哪些信息会被遗忘
  • 输入门决定哪些新的信息会被保留。

例如:

我昨天吃了拉面,今天我想吃炒饭,在这个句子中,通过遗忘门可以遗忘拉面,同时更新新的主语为炒饭。

输出门

我们需要决定什么信息会被输出,也是一样这个输出经过变换之后会通过sigmoid函数的结果来决定那些细胞状态会被输出。

  1. 前一次的输出和当前时间步的输入的组合结果通过sigmoid函数进行处理得到O_t

  2. 更新后的细胞状态C_t会经过tanh层的处理,把数据转化到(-1,1)的区间

  3. tanh处理后的结果和O_t进行相乘,把结果输出同时传到下一个LSTM的单元

8ca0b205bcfa44e18c3af5b4f7271880.png 

GRU

GRU是一种LSTM的变形版本, 它将遗忘和输入门组合成一个“更新门”。它还合并了单元状态和隐藏状态,并进行了一些其他更改,由于他的模型比标准LSTM模型简单,所以越来越受欢迎。

664e50357e604f918c707643ca15bc9c.png

b429639b6a994ec099f87d8adf609263.png 

双向LSTM

单向的 RNN,是根据前面的信息推出后面的,但有时候只看前面的词是不够的, 可能需要预测的词语和后面的内容也相关,那么此时需要一种机制,能够让模型不仅能够从前往后的具有记忆,还需要从后往前需要记忆。此时双向LSTM就可以帮助我们解决这个问题

f990226c2e3a4c9da262cc74ff2201e4.png 

由于是双向LSTM,所以每个方向的LSTM都会有一个输出,最终的输出会有2部分,所以往往需要concat的操作。

96f81f98d8e74dadaa1f4925a3406007.pngRNN实现文本情感分类 

torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first,dropout,bidirectional)
  1. input_size:输入数据的形状,即embedding_dim

  2. hidden_size:隐藏层神经元的数量,即每一层有多少个LSTM单元

  3. num_layer :即RNN的中LSTM单元的层数

  4. batch_first:默认值为False,输入的数据需要[seq_len,batch,feature],如果为True,则为[batch,seq_len,feature]

  5. dropout:dropout的比例,默认值为0。dropout是一种训练过程中让部分参数随机失活的一种方式,能够提高训练速度,同时能够解决过拟合的问题。

  6. bidirectional:是否使用双向LSTM,默认是False

实例化LSTM对象之后,不仅需要传入数据,还需要前一次的h_0(前一次的隐藏状态)和c_0

LSTM的默认输出为output, (h_n, c_n)  

  1. output(seq_len, batch, num_directions * hidden_size)--->batch_first=False

  2. h_n:(num_layers * num_directions, batch, hidden_size)

  3. c_n: (num_layers * num_directions, batch, hidden_size)

 4b9843ea2e35484f86a90641afd0fff6.png

LSTM和GRU的使用注意点

  1. 第一次调用之前,需要初始化隐藏状态,如果不初始化,默认创建全为0的隐藏状态

  2. 往往会使用LSTM or GRU 的输出的最后一维的结果,来代表LSTM、GRU对文本处理的结果,其形状为[batch, num_directions*hidden_size]

使用LSTM完成文本情感分类

class IMDBLstmmodel(nn.Module):def __init__(self):super(IMDBLstmmodel,self).__init__()self.hidden_size = 64self.embedding_dim = 200self.num_layer = 2self.bidriectional = Trueself.bi_num = 2 if self.bidriectional else 1self.dropout = 0.5self.embedding = nn.Embedding(len(ws),self.embedding_dim,padding_idx=ws.PAD) #[N,300]self.lstm = nn.LSTM(self.embedding_dim,self.hidden_size,self.num_layer,bidirectional=True,dropout=self.dropout)self.fc = nn.Linear(self.hidden_size*self.bi_num,20)self.fc2 = nn.Linear(20,2)def forward(self, x):x = self.embedding(x)x = x.permute(1,0,2) h_0,c_0 = self.init_hidden_state(x.size(1))_,(h_n,c_n) = self.lstm(x,(h_0,c_0))out = torch.cat([h_n[-2, :, :], h_n[-1, :, :]], dim=-1)out = self.fc(out)out = F.relu(out)out = self.fc2(out)return F.log_softmax(out,dim=-1)def init_hidden_state(self,batch_size):h_0 = torch.rand(self.num_layer * self.bi_num, batch_size, self.hidden_size).to(device)c_0 = torch.rand(self.num_layer * self.bi_num, batch_size, self.hidden_size).to(device)return h_0,c_0

为了提高程序的运行速度,可以考虑把模型放在GPU上运行:

  1. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  2. model.to(device)

train_batch_size = 64
test_batch_size = 5000
imdb_model = IMDBLstmmodel().to(device) 
optimizer = optim.Adam(imdb_model.parameters())
criterion = nn.CrossEntropyLoss()def train(epoch):mode = Trueimdb_model.train(mode)train_dataloader =get_dataloader(mode,train_batch_size)for idx,(target,input,input_lenght) in enumerate(train_dataloader):target = target.to(device)input = input.to(device)optimizer.zero_grad()output = imdb_model(input)loss = F.nll_loss(output,target) loss.backward()optimizer.step()if idx %10 == 0:pred = torch.max(output, dim=-1, keepdim=False)[-1]acc = pred.eq(target.data).cpu().numpy().mean()*100.print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t ACC: {:.6f}'.format(epoch, idx * len(input), len(train_dataloader.dataset),100. * idx / len(train_dataloader), loss.item(),acc))torch.save(imdb_model.state_dict(), "model/mnist_net.pkl")torch.save(optimizer.state_dict(), 'model/mnist_optimizer.pkl')def test():mode = Falseimdb_model.eval()test_dataloader = get_dataloader(mode, test_batch_size)with torch.no_grad():for idx,(target, input, input_lenght) in enumerate(test_dataloader):target = target.to(device)input = input.to(device)output = imdb_model(input)test_loss  = F.nll_loss(output, target,reduction="mean")pred = torch.max(output,dim=-1,keepdim=False)[-1]correct = pred.eq(target.data).sum()acc = 100. * pred.eq(target.data).cpu().numpy().mean()print('idx: {} Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(idx,test_loss, correct, target.size(0),acc))if __name__ == "__main__":test()for i in range(10):train(i)test()

然后由大家写代码得到模型训练的最终输出,大家可以改变模型来观察不同的结果。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3248822.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

运算符的使用

一、运算符介绍 运算符是一种特殊的符号,用以表示数据的运算、赋值和比较等 算术运算符赋值运算符比较运算符逻辑运算符位运算符 二、算术运算符 1、算术运算符是对数值类型的变量进行运算的,在程序中使用的非常多 2、算术运算符的使用 # 算术运算符…

Learning vtkjs之vtkSource

vtkSource的主要类型 Cone 锥体Circle 圆形Arrow 箭头ConcentricCylinder 同心圆Cube 方形Cursor3D 包围盒Cylinder 圆柱体Line 线Plane 平面Point 点Sphere 球不能调整center的source 目前整理的有下面几种source,对应有点类似threejs的mesh,通过一定的…

【.NET全栈】ASP.NET开发Web应用——站点导航技术

文章目录 前言一、站点地图1、定义站点地图文件2、使用SiteMapPath控件3、SiteMap类4、URL地址映射 二、TreeView控件1、使用TreeView控件2、以编程的方式添加节点3、使用TreeView控件导航4、绑定到XML文件5、按需加载节点6、带复选框的TreeView控件 三、Menu控件1、使用Menu控…

C语言指针超详解——进阶篇

C语言指针系列文章目录 入门篇 强化篇 进阶篇 文章目录 C语言指针系列文章目录1. 字符指针变量2. 数组指针变量2. 1 概念2. 2 数组指针变量的初始化 3. 二维数组传参的本质4. 函数指针变量4. 1 函数指针变量的创建4. 2 指针变量的使用4. 3 两个有趣的代码4. 3. 1 代码一4. 3. …

c++初阶知识——内存管理与c语言内存管理对比

目录 前言: 1.c++内存管理方式 1.1 new和delete操作自定义类型 2.operator new与operator delete函数 2.1 operator new与operator delete函数 3.new和delete的实现原理 3.1 内置类型 3.2 自定义类型 new的原理 delete的原理 new…

完整教程 linux下安装百度网盘以及相关依赖库,安装完成之后启动没反应 或者 报错

完整教程 linux下安装百度网盘以及相关依赖库,安装完成之后启动没反应 或者 报错。 配置国内镜像源: yum -y install wget mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.bak wget -O /etc/yum.repos.d/CentOS-Base.repo ht…

数据库端口LookUp功能:从数据库中获取并添加数据到XML

本文将为大家介绍如何使用知行之桥EDI系统数据库端口的Lookup功能,从数据库中获取数据,并添加进输入的XML中。 使用场景:期待以输入xml中的值为判断条件从数据库中获取数据,并添加进输入xml中。 例如:接收到包含采购…

pyqt/pyside QTableWidget失去焦点后,选中的行仍高亮的显示

正常情况下pyqt/pyside的QTableWidget,点击input或者按钮失去焦点后 行的颜色消失了 如何在失去焦点时保持行的选中颜色,增加下面的代码: # 获取当前表格部件的调色板 p tableWidget.palette()# 获取活跃状态下的高亮颜色和高亮文本颜色&a…

AWS-S3实现Minio分片上传、断点续传、秒传、分片下载、暂停下载

文章目录 前言一、功能展示上传功能点下载功能点效果展示 二、思路流程上传流程下载流程 三、代码示例四、疑问 前言 Amazon Simple Storage Service(S3),简单存储服务,是一个公开的云存储服务。Web应用程序开发人员可以使用它存…

HZNUCTF2023中web相关题目

[HZNUCTF 2023 preliminary]guessguessguess 这道题目打不开了 [HZNUCTF 2023 preliminary]flask 这道题目考察SSTI倒序的模板注入,以及用env命令获得flag 看题目,猜测是SSTI模板注入,先输入{7*7},发现模板是倒序输入的 输入}}7*7{{返回77…

Postgresql主键自增的方法

Postgresql主键自增的方法 一.方法(一) 使用 serial PRIMARY KEY 插入数据 二.方法(二) 🎈边走、边悟🎈迟早会好 一.方法(一) 使用 serial PRIMARY KEY 建表语句如下&#xf…

学生管理系统(C语言)(Easy-x)

课 程 报 告 课 程 名 称: 程序设计实践 专 业 班 级 : XXXXX XXXXX 学 生 姓 名 : XXX 学 号 : 231040700302 任 课 教 师 &a…

C++类与对象(补)

感谢大佬的光临各位,希望和大家一起进步,望得到你的三连,互三支持,一起进步 个人主页:LaNzikinh-CSDN博客 文章目录 前言一.默认成员函数二.static三.友元四.匿名对象总结 前言 类的默认成员函数,默认成员…

Mongodb数据库(上)

介绍 是一个基于磁盘存储的开源的、文档类型(数据存储格式)的非关系型数据库。 其数据首先是存放到内存中,当内存不够时,它还可以存放到磁盘里面去 优点 基本概念 数据库 mongodb中的数据库默认是’test‘(就是一进去就是直接使用用的test数据库),如果想要使用其他…

【LabVIEW作业篇 - 2】:分数判断、按钮控制while循环暂停、单击按钮获取book文本

文章目录 分数判断按钮控制while循环暂停按钮控制单个while循环暂停 按钮控制多个while循环暂停单击按钮获取book文本 分数判断 限定整型数值输入控件值得输入范围,范围在0-100之间,判断整型数值输入控件的输入值。 输入范围在0-59之间,显示…

【Python进阶】正则表达式、pymysql模块

目录 一、正则表达式的概述 1、基本介绍 2、快速使用re模块 二、正则的常见规则 1、匹配单个字符 2、原始字符串 3、匹配多个字符 4、匹配开头和结尾 5、匹配分组 三、Python与MySQL交互 1、pymysql模块的安装 2、pymysql的操作步骤 3、connection对象 4、cursor…

基于ANSIBLE中的YAML非标记语言Role角色扮演

YAML-YAML Ain’t Markup Language-非标记语言 语法 列表 fruits:​ - Apple​ - Orange​ - Strawberry​ - Mango 字典 martin:​ name : Martin D’vloper​ job : Developer​ skill : Elite 示例1 需求 通过YAML编写一个简单的剧本,完成web的部署&#xff0c…

【Mongodb-04】Mongodb聚合管道操作基本功能

Mongodb系列整体栏目 内容链接地址【一】Mongodb亿级数据性能测试和压测https://zhenghuisheng.blog.csdn.net/article/details/139505973【二】springboot整合Mongodb(详解)https://zhenghuisheng.blog.csdn.net/article/details/139704356【三】亿级数据从mysql迁移到mongodb…

【Springboot】新增profile环境配置应用启动失败

RT 最近接手了一个新的项目,为了不污染别人的环境,我新增了一个自己的环境配置。结果,在启动的时候总是失败,就算是反复mvn clean install也是无效。 问题现象 卡住无法进行下一步 解决思路 由于之前都是能启动的&#xff0c…

视频素材网站无水印的有哪些?热门视频素材网站分享

当我们走进视频创作的精彩世界时,一个难题常常摆在面前——那些高品质、无水印的视频素材究竟应该在哪里寻找?许多视频创作者感叹,寻找理想的视频素材难度甚至超过了寻找伴侣!但不用担心,今天我将为您介绍几个优质的视…