GRU模块:nn.GRU层的介绍

       在 GRU(Gated Recurrent Unit)中,outputstate 都是由 GRU 层的循环计算产生的,它们之间有直接的关系。state 实际上是 output 中最后一个时间步的隐藏状态。

GRU 的基本公式

GRU 的核心计算包括更新门(update gate)和重置门(reset gate),以及候选隐藏状态(candidate hidden state)。数学表达式如下:

  1. 更新门 \( z_t \): \[ z_t = \sigma(W_z \cdot h_{t-1} + U_z \cdot x_t) \]
       其中,\( \sigma \) 是sigmoid 函数,\( W_z \) 和 \( U_z \) 分别是对应于隐藏状态和输入的权重矩阵,\( h_{t-1} \) 是上一个时间步的隐藏状态,\( x_t \) 是当前时间步的输入。

  2. 重置门 \( r_t \):
       \[ r_t = \sigma(W_r \cdot h_{t-1} + U_r \cdot x_t) \]
       \( W_r \) 和 \( U_r \) 是更新门中定义的相似权重矩阵。

  3. 候选隐藏状态 \( \tilde{h}_t \):
       \[ \tilde{h}_t = \tanh(W \cdot r_t \odot h_{t-1} + U \cdot x_t) \]
       这里,\( \tanh \) 是激活函数,\( \odot \) 表示元素乘法(Hadamard product),\( W \) 和 \( U \) 是隐藏状态的权重矩阵。

  4. 最终隐藏状态 \( h_t \):
       \[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

output 和 state 的关系

  • output:在 GRU 中,output 包含了序列中每个时间步的隐藏状态。具体来说,对于每个时间步 \( t \),output 的第 \( t \) 个元素就是该时间步的隐藏状态 \( h_t \)。

  • state:state 是 GRU 层最后一层的隐藏状态,也就是 output 中最后一个时间步的隐藏状态 \( h_{T-1} \),其中 \( T \) 是序列的长度。

数学表达式

如果我们用 \( O \) 表示 output,\( S \) 表示 state,\( T \) 表示时间步的总数,那么:

\[ O = [h_0, h_1, ..., h_{T-1}] \]
\[ S = h_{T-1} \]

因此,state 实际上是 output 中最后一个元素,即 \( S = O[T-1] \)。

在 PyTorch 中,output 和 state 都是由 GRU 层的 `forward` 方法计算得到的。`output` 是一个三维张量,包含了序列中每个时间步的隐藏状态,而 `state` 是一个二维张量,仅包含最后一个时间步的隐藏状态。

代码示例

class Seq2SeqEncoder(d2l.Encoder):
"""⽤于序列到序列学习的循环神经⽹络编码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):super(Seq2SeqEncoder, self).__init__(**kwargs)# 嵌⼊层self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,dropout=dropout)def forward(self, X, *args):# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X)# 在循环神经⽹络模型中,第⼀个轴对应于时间步X = X.permute(1, 0, 2)# 如果未提及状态,则默认为0output, state = self.rnn(X)# output的形状:(num_steps,batch_size,num_hiddens)# state的形状:(num_layers,batch_size,num_hiddens)return output, state

output:在完成所有时间步后,最后⼀层的隐状态的输出output是⼀个张量(output由编码器的循环层返回),其形状为(时间步数,批量⼤⼩,隐藏单元数)。

state:最后⼀个时间步的多层隐状态是state的形状是(隐藏层的数量,批量⼤⼩, 隐藏单元的数量)。

GRU的内部实现

上面一节的代码示例,是通过调用API实现的,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)。那么,GRU内部是如何实现的呢?

分为模型、模型参数初始化和隐状态初始化三个部分:

模型定义(模型定义与数学表示式一致,也可以参考上图):

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

  模型参数初始化(权重是从标准差0.01的高斯分布中提取的,超参数num_hiddens定义隐藏单元的数量,偏置项设置为0):

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three() # 更新⻔参数W_xr, W_hr, b_r = three() # 重置⻔参数W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)
# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

隐状态初始化函数(此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值都为0

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

最后由一个函数统一起来,实现模型:

model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)

       总体上说,内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。但是,如果需要深入理解GRU的话,那么内部实现的详细代码和计算公式就比较重要,中间的一些过程和变量的意义需要详细关注,只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等,所以,仔细研究这个模块的实现还是非常有必要的。对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3018289.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

git相关操作命令

配置 Git #查看配置信息 git config --global --listgit config --global user.name "Your Name" git config --global user.email "youremailexample.com" 初始化Git 仓库 git init git clone https://github.com/username/repository.gitgit remote -…

【JavaWeb】网上蛋糕项目商城-注册,登录,修改用户信息,提交订单

概念 通过以上多篇文章的讲解,对该项目的功能已经实现了很多,本文将对该项目的用户注册,登录,修改用户信息,以及用户添加至购物车的商品进行提交订单等功能的实现。 注册功能实现 点击head.jsp头部页面的注册按钮&a…

微信小程序 手机号授权登录

手机号授权登录 效果展示 这里面用的是 uni-app 官方的登录 他支持多端发布 https://zh.uniapp.dcloud.io/api/plugins/login.html#loginhttps://zh.uniapp.dcloud.io/api/plugins/login.html#login 下面是代码 <template><!-- 授权按钮 --><button v-if&quo…

3D人体展示仪

网址 https://3dbodyvisualizer.com/ 可以根据身高体重之类的在线生成人体的3D模型&#xff0c;感兴趣的可以试试

vue3+arco design通过动态表单方式实现自定义筛选

目录 1.说明 2.示例 3.运行截图 ​编辑 4.总结 1.说明 (1) 本文主要实现通过动态表单的方式实现自定义筛选的功能&#xff0c;用户可以自己添加筛选的项目&#xff0c;筛选条件及筛选内容。 (2) 每个项目的筛选包含筛选项目&#xff0c;筛选条件&#xff0c;筛选方式及筛选…

重学java 30.API 1.String字符串

于是&#xff0c;虚度的光阴换来了模糊 —— 24.5.8 一、String基础知识以及创建 1.String介绍 1.概述 String类代表字符串 2.特点 a.Java程序中的所有字符串字面值(如“abc”)都作为此类的实例(对象)实现 凡是带双引号的&#xff0c;都是String的对象 String s "abc&q…

在家中访问一个网站的思考

在家中访问一个网站的思考 1、家庭网络简介2、家庭WLAN DHCP2.1、家庭路由器PPPOE拨号2.2、DHCP&#xff08;动态主机配置协议&#xff09;2.3、接入家庭网的主机IP地址2.4、家庭总线型以太网2.5、Mac地址2.6、ARP协议2.7、IP协议 & UDP/TCP协议2.8、NAT&#xff08;Netwo…

【一起深度学习吧!!!!!】24/05/03

卷积层里的多输入输出通道 1、 多输入通道&#xff1a;代码演示&#xff1a; 多输出通道&#xff1a;代码实现&#xff1a; 1、 多输入通道&#xff1a; 当输入包含多个通道时&#xff0c;需要构造一个输入通道与之相等的卷积核&#xff0c;以便进行数据互相关计算。 例如李沐…

Ubuntu24.04安装中文输入法

Ubuntu24.04安装中文输入法 为了更好的体验&#xff0c;请访问个人博客 www.huerpu.cc:7000 一、添加中文语言支持 在安装中文输入法之前&#xff0c;首选要添加中文语言支持。选择System&#xff0c;点击Region & Language。 点击Manage Install Languages。 点击Insta…

repo跟git的关系

关于repo 大都讲的太复杂了,大多是从定义角度跟命令角度去讲解,其实从现实项目使用角度而言repo很好理解. 我们都知道git是用来管理项目的,多人开发过程中git功能很好用.现在我们知道一个项目会用一个git仓库去管理,项目的开发过程中会使用git创建分支之类的来更好的维护项目代…

css 文字描边

又是抄样式的一天。这次是百度地图。实现了问题和图形描边的效果。 代码&#xff1a; .BMap_scaleTxt.dark {color: #fff;text-shadow: -1px -1px 0 #000, 1px -1px 0 #000, -1px 1px 0 #000, 1px 1px 0 #000; } 效果&#xff1a;

安装numpy遇到的问题

安装numpy的时候提示无法安装如下&#xff1a; (venv) E:\works\AI\venv\Scripts>pip install numpy pandas matplotlib jupyter -i https://pypi.douban.com/simple Looking in indexes: https://pypi.douban.com/simple WARNING: Retrying (Retry(total4, connectNone, r…

分析师常用商业分析模型

一、背景 在用户调研中&#xff0c;我们发现分析师对商业分析模型的使用还是比较频繁。本文主要对用户调研结果中的分析师常用商业分析模型以及一些业界经典的商业分析模型进行分析&#xff0c;并梳理出执行落地流程&#xff0c;以此来指导分析师工具设计分析功能的引导性。 …

软件测试--接口测试

接口测试&#xff1a;直接对后端服务的测试&#xff0c;是服务端性能测试的基础 接口&#xff1a;系统之间数据交互的通道 接口测试&#xff1a;校验接口响应数据与预期数据是否一致

【JavaEE初阶系列】——Servlet运行原理以及Servlet API详解

目录 &#x1f6a9;Servlet运行原理 &#x1f6a9;Servlet API 详解 &#x1f393;HttpServlet核心方法 &#x1f393;HttpServletRequest核心方法 &#x1f388;核心方法的使用 &#x1f534;获取请求中的参数 &#x1f4bb;query string &#x1f4bb;直接通过form表…

回归分析的理解

1.是什么&#xff1a; 2.回归问题的求解&#xff1a; 首先是根据之前的数据确定变量和因变量的关系根据关系去预测目标数据根据结果做出判断 2.1如何找到关系&#xff1f; y’是根据模型生成的预测结果&#xff1a; y’axb&#xff0c;而我们的目的是y’和y(正确的结果)之间…

构造照亮世界——快速沃尔什变换 (FWT)

博客园 我的博客 快速沃尔什变换解决的卷积问题 快速沃尔什变换&#xff08;FWT&#xff09;是解决这样一类卷积问题&#xff1a; ci∑ij⊙kajbkc_i\sum_{ij\odot k}a_jb_k ci​ij⊙k∑​aj​bk​其中&#xff0c;⊙\odot⊙ 是位运算的一种。举个例子&#xff0c;给定数列 a,…

二叉搜索树相关

二叉搜索树 定义&#xff1a;对二叉搜索树的一些操作基本结构Insert操作Find操作Erase操作 InOrder遍历二叉树操作模拟字典模拟统计次数 定义&#xff1a; 二叉搜索树又称二叉排序树&#xff0c;它或者是一棵空树&#xff0c;或者是具有以下性质的二叉树:若它的左子树不为空&a…

品鉴中的艺术表达:如何将红酒与绘画、雕塑等艺术形式相结合

品鉴雷盛红酒不仅是一种味觉的享受&#xff0c;更是一种艺术的体验。将雷盛红酒与绘画、雕塑等艺术形式相结合&#xff0c;能够创造出与众不同的审美体验&#xff0c;进一步丰富品鉴的内涵。 首先&#xff0c;绘画作为视觉艺术的一种表现形式&#xff0c;能够通过色彩和构图来传…

Linux:进程等待 进程替换

Linux&#xff1a;进程等待 & 进程替换 进程等待wait接口statuswaitpid接口 进程替换exec系列接口 当一个进程死亡后&#xff0c;会变成僵尸进程&#xff0c;此时进程的PCB被保留&#xff0c;等待父进程将该PCB回收。那么父进程要如何回收这个僵尸进程的PCB呢&#xff1f;父…