Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)
Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署

tips:安装依赖库

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install tqdm requests

一、RNN模型构建

数据集准备完成了输入文本通过查字典(序列化)的向量化。并使用nn.Embedding层加载了Glove词向量。下一步将使用RNN循环神经网络做特征提取,最后将RNN连接至全连接网络nn.Dednse,将特征转化为分类。

nn.Embedding -> nn.RNN -> nn.Dense

本项目,采用规避RNN梯度消的变种LSTM(Long short-term memory)代替RNN做特征提取层。

1.1 关于RNN

循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。下图为RNN的一般结构:

RNN-0

图示左侧为一个RNN Cell循环,右侧为RNN的链式连接平铺。实际上不管是单个RNN Cell还是一个RNN网络,都只有一个Cell的参数,在不断进行循环计算中更新。

由于RNN的循环特性,和自然语言文本的序列特性(句子是由单词组成的序列)十分匹配,因此被大量应用于自然语言处理研究中。下图为RNN的结构拆解:

RNN

1.2 关于LSTM(Long short-term memory)

RNN单个Cell的结构简单,因此也造成了梯度消失(Gradient Vanishing)问题,具体表现为RNN网络在序列较长时,在序列尾部已经基本丢失了序列首部的信息。为了克服这一问题,LSTM(Long short-term memory)被提出,通过门控机制(Gating Mechanism)来控制信息流在每个循环步中的留存和丢弃。下图为LSTM的结构拆解:

LSTM

本项目选择LSTM变种而不是经典的RNN做特征提取,可规避梯度消失问题,并获得更好的模型效果。
在MindSpore中nn.LSTM对应的公式:

h 0 : t , ( h t , c t ) = LSTM ( x 0 : t , ( h 0 , c 0 ) ) h_{0:t}, (h_t, c_t) = \text{LSTM}(x_{0:t}, (h_0, c_0)) h0:t,(ht,ct)=LSTM(x0:t,(h0,c0))

这里nn.LSTM隐藏了整个循环神经网络在序列时间步(Time step)上的循环,送入输入序列、初始状态,即可获得每个时间步的隐状态(hidden state`)拼接而成的矩阵,以及最后一个时间步对应的隐状态。我们使用最后的一个时间步的隐状态作为输入句子的编码特征,送入下一层

Time step:在循环神经网络计算的每一次循环,成为一个Time step。在送入文本序列时,一个Time step对应一个单词。因此在本例中,LSTM的输出 h 0 : t h_{0:t} h0:t对应每个单词的隐状态集合, h t h_t ht c t c_t ct对应最后一个单词对应的隐状态。

下一层:全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

1.3 特征提取网络构建

RNN循环神经网络: nn.LSTM()
初始化参数:

 embeddings:输入向量,hidden_dim:隐藏层特征的维度, output_dim:输出维数, n_layers:RNN 层的数量,bidirectional:是否为双向 RNN, pad_idx:padding_idx参数用于标记输入中的填充值(padding value)。在自然语言处理任务中,文本序列的长度不一致是非常常见的。为了能够对不同长度的文本序列进行批处理,我们通常会使用填充值对较短的序列进行填补。

tips:使用nn.embeddings()创建嵌入层时,可以通过padding_idx参数指定一个特定的索引,用于表示填充值。
embedding_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0),将padding_idx设置为0,表示使用索引为0的词汇作为填充值。在文本序列中,我们将使用0来填充较短的序列。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers,bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers=n_layers,bidirectional=bidirectional,batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output

实例化模型,打印输出

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
print(model)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3266292.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

数据集成工具之kettle

Kettle 是一个用于数据集成的开源工具&#xff0c;由 Pentaho 开发&#xff0c;现已由 Hitachi Vantara 维护。Kettle 的全名是 Pentaho Data Integration (PDI)&#xff0c;主要用于数据提取、转换和加载&#xff08;ETL&#xff09;过程。 1. 核心组件 Spoon: 图形化的设计工…

Java | Leetcode Java题解之第283题移动零

题目&#xff1a; 题解&#xff1a; class Solution {public void moveZeroes(int[] nums) {int n nums.length, left 0, right 0;while (right < n) {if (nums[right] ! 0) {swap(nums, left, right);left;}right;}}public void swap(int[] nums, int left, int right)…

【Spring Boot教程:从入门到精通】掌握Spring Boot开发技巧与窍门(三)-配置git环境和项目创建

主要介绍了如何创建一个Springboot项目以及运行Springboot项目访问内部的html页面&#xff01;&#xff01;&#xff01; 文章目录 前言 配置git环境 创建项目 ​编辑 在SpringBoot中解决跨域问题 配置Vue 安装Nodejs 安装vue/cli 启动vue自带的图形化项目管理界面 总结 前言 …

k8s核心知识总结

写在前面 时间一下子到了7月份尾&#xff1b;整个7月份都乱糟糟的&#xff0c;不管怎么样&#xff0c;日子还是得过啊&#xff0c; 1、7月份核心了解个关于k8s&#xff0c;iceberg等相关技术&#xff0c;了解了相关的基础逻辑&#xff0c;虽然和数开主线有点偏&#xff0c;但是…

【b站-湖科大教书匠】5 运输层 - 计算机网络微课堂

课程地址&#xff1a;【计算机网络微课堂&#xff08;有字幕无背景音乐版&#xff09;】 https://www.bilibili.com/video/BV1c4411d7jb/?share_sourcecopy_web&vd_sourceb1cb921b73fe3808550eaf2224d1c155 目录 5 运输层 5.1 运输层概述 5.2 运输层端口号、复用与分用…

三维重建 概论

三维重建的方式 通俗来讲,三维重建就是将2D的数据生成3D的模型。 首先将2D的物体,通过各种方法,算法形成几何网格对象,同时用深度信息,处理远近,遮罩等关系,形成最终的3D模型。 在计算机视觉中,三维重建是指根据单视图或者多视图图像重建原始三维信息的过程。 单视图缺…

简单使用nginx

打开下载的nginx文件夹下的。。具体地址 打开并编辑nginx.conf文件 server {listen 8089;//访问端口号server_name localhost;//访问地址#charset koi8-r;#access_log logs/host.access.log main;location / {root D:/development/dist/;//dist包地址index index.h…

论文阅读:面向自动驾驶场景的多目标点云检测算法

论文地址:面向自动驾驶场景的多目标点云检测算法 概要 点云在自动驾驶系统中的三维目标检测是关键技术之一。目前主流的基于体素的无锚框检测算法通常采用复杂的二阶段修正模块,虽然在算法性能上有所提升,但往往伴随着较大的延迟。单阶段无锚框点云检测算法简化了检测流程,…

littlefs文件系统的移植和测试

简介 LittleFS 由ARM官方发布&#xff0c;ARM mbedOS的官方推荐文件系统&#xff0c;具有轻量级&#xff0c;掉电安全的特性。主要用在微控制器和flash上 掉电恢复&#xff0c;在写入时即使复位或者掉电也可以恢复到上一个正确的状态。 擦写均衡&#xff0c;有效延长flash的使…

基于JSP、java、Tomcat三者的项目实战--校园交易网(1)-项目搭建(前期准备工作)

这是项目的初始页面 接下来我先写下我的初始项目搭建 技术支持&#xff1a;JAVA、JSP 服务器&#xff1a;TOMCAT 7.0.86 编程软件&#xff1a;IntelliJ IDEA 2021.1.3 x64 首先我们打开页面&#xff0c;准备搭建项目的初始准备 1.New Project 2.随后点击Next&#xff0c;勾…

C++ | Leetcode C++题解之第287题寻找重复数

题目&#xff1a; 题解&#xff1a; class Solution { public:int findDuplicate(vector<int>& nums) {int slow 0, fast 0;do {slow nums[slow];fast nums[nums[fast]];} while (slow ! fast);slow 0;while (slow ! fast) {slow nums[slow];fast nums[fast]…

栈与递归

1.递归定义的数学函数 2.具有递归特性的数据结构 3.

vue3前端开发-小兔鲜项目-登录功能的业务接口调用

vue3前端开发-小兔鲜项目-登录功能的业务接口调用!这次&#xff0c;正式调用远程服务器的登录接口了。大家要必须使用测试账号密码&#xff0c;才能验证我们的代码。 测试账号密码是&#xff1a;账号&#xff08;xiaotuxian001&#xff09;&#xff1b;密码是&#xff08;1234…

Pytorch使用教学7-张量的广播

PyTorch中的张量具有和NumPy相同的广播特性&#xff0c;允许不同形状的张量之间进行计算。 广播的实质特性&#xff0c;其实是低维向量映射到高维之后&#xff0c;相同位置再进行相加。我们重点要学会的就是低维向量如何向高维向量进行映射。 相同形状的张量计算 虽然我们觉…

FreeRTOS操作系统(详细速通篇)——— 第六章

本专栏将对FreeRTOS进行快速讲解&#xff0c;带你了解并使用FreeRTOS的各部分内容。适用于快速了解FreeRTOS并进行开发、突击面试、对新手小白非常友好。期待您的后续关注和订阅&#xff01; 目录 系统中断管理 1 什么是中断&#xff1f; 1.1中断定义 1.2 中断执行机制 ​…

Chiplet SPI User Guide 详细解读

目录 一. 基本介绍 1.1.整体结构 1.2. 结构细节与功能描述 二. 输入输出接口 2.1. IO Ports for SPI Leader 2.2. IO Ports for SPI Follower 2.3. SPI Mode Configuration 2.4. Leader IP和Follower IP功能图 三. SPI Programming 3.1. Leader Register Descripti…

算法:数值算法

矩阵乘法 定义与性质 矩阵乘法是线性代数中的一个基本运算&#xff0c;它涉及到两个矩阵的点积运算。给定两个矩阵 A&#xff08;mn&#xff09;和 B&#xff08;np&#xff09;&#xff0c;它们的乘积 C&#xff08;mp&#xff09;定义为&#xff1a; 其中&#xff0c; Cij …

大连智点文化传媒有限公司介绍

在辽宁省大连市的文化传媒领域,大连智点文化传媒有限公司(以下简称“智点文化”)以其独特的魅力和专业的服务,逐渐崭露头角。作为一家集广告、文化、营销策划等多功能于一体的综合性文化传媒公司,智点文化不仅拥有深厚的行业底蕴,还不断探索与创新,以适应快速变化的市场需求。 …

在英特尔 Gaudi 2 上加速蛋白质语言模型 ProtST

引言 蛋白质语言模型 (Protein Language Models, PLM) 已成为蛋白质结构与功能预测及设计的有力工具。在 2023 年国际机器学习会议 (ICML) 上&#xff0c;MILA 和英特尔实验室联合发布了ProtST模型&#xff0c;该模型是个可基于文本提示设计蛋白质的多模态模型。此后&#xff0…

AI发展下的伦理挑战:构建未来科技的道德框架

一、引言 随着人工智能&#xff08;AI&#xff09;技术的飞速发展&#xff0c;我们正处在一个前所未有的科技变革时代。AI不仅在医疗、教育、金融、交通等领域展现出巨大的应用潜力&#xff0c;也在日常生活中扮演着越来越重要的角色。然而&#xff0c;这一技术的迅猛进步也带来…