Transformer的PyTorch实现之若干问题探讨(二)

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。

1.Transformer中decoder的流程

在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
在这里插入图片描述

上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?

2.Decoder中的self attention mask

class Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])def forward(self, dec_inputs, enc_inputs, enc_outputs):'''这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的dec_inputs: [batch_size, tgt_len]enc_inpus: [batch_size, src_len]enc_outputs: [batch_size, src_len, d_model]'''dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵# 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为Falsedec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +dec_self_attn_subsequence_mask), 0).cuda()# 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]for layer in self.layers:dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]

上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:

dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]],[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]],[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]

可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:

    def forward(self, Q, K, V, attn_mask):'''Q: [batch_size, n_heads, len_q, d_k]K: [batch_size, n_heads, len_k, d_k]V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的attn_mask: [batch_size, n_heads, seq_len, seq_len]'''# 1) 计算注意力分数QK^T/sqrt(d_k)scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]# 2)  进行 mask 和 softmax# mask为True的位置会被设为-1e9scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9attn = nn.Softmax(dim=-1)(scores)  # attn: [batch_size, n_heads, len_q, len_k]# 3) 乘V得到最终的加权和context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]'''得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来'''return context # context: [batch_size, n_heads, len_q, d_v]

其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
在这里插入图片描述

如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2777030.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

【前端高频面试题--TypeScript篇】

🚀 作者 :“码上有前” 🚀 文章简介 :前端高频面试题 🚀 欢迎小伙伴们 点赞👍、收藏⭐、留言💬 高频前端面试题--Vue3.0篇 什么是TypeScript?TypeScript数据类型TypeScript中命名空…

Python爬虫实战:抓取猫眼电影排行榜top100#4

爬虫专栏系列:http://t.csdnimg.cn/Oiun0 抓取猫眼电影排行 本节中,我们利用 requests 库和正则表达式来抓取猫眼电影 TOP100 的相关内容。requests 比 urllib 使用更加方便,而且目前我们还没有系统学习 HTML 解析库,所以这里就…

深度学习自然语言处理(NLP)模型BERT:从理论到Pytorch实战

文章目录 深度学习自然语言处理(NLP)模型BERT:从理论到Pytorch实战一、引言传统NLP技术概览规则和模式匹配基于统计的方法词嵌入和分布式表示循环神经网络(RNN)与长短时记忆网络(LSTM)Transform…

vim常用命令以及配置文件

layout: article title: “vim文本编译器” vim文本编辑器 有三种模式: 命令模式 文本模式, 末行模式 vim命令大全 - 知乎 (zhihu.com) 命令模式 插入 i: 切换到输入模式,在光标当前位置开始输入文本。 a: 进入插入模式,在光标下一个位置开始输入文…

Excel+VBA处理高斯光束

文章目录 1 图片导入与裁剪2 获取图片数据3 数据拟合 1 图片导入与裁剪 插入图片没什么好说的,新建Excel,【插入】->【图片】。 由于图像比较大,所以要对数据进行截取,选中图片之后,点击选项卡右端的【图片格式】…

前后端通讯:前端调用后端接口的五种方式,优劣势和场景

Hi,我是贝格前端工场,专注前端开发8年了,前端始终绕不开的一个话题就是如何和后端交换数据(通讯),本文先从最基础的通讯方式讲起。 一、什么是前后端通讯 前后端通讯(Frontend-Backend Commun…

HiveSQL——共同使用ip的用户检测问题【自关联问题】

注:参考文章: SQL 之共同使用ip用户检测问题【自关联问题】-HQL面试题48【拼多多面试题】_hive sql 自关联-CSDN博客文章浏览阅读810次。0 问题描述create table log( uid char(10), ip char(15), time timestamp);insert into log valuesinsert into l…

Java 学习和实践笔记(5)

三种类型的变量: Java中常量的定义: 下面的这个加号表示连接的意思,也就是把前面的字符串常量和后面的变量值在显示时连在一起: 显示效果如下: 如果没有用这个加号,就会报错:

037 稀疏数组

代码示例 /*** 生成稀疏数组* param arr 原数组* param defaultValue 数组默认值* return*/ static int[][] extractArray(int[][] arr, int defaultValue) {// 统计有多少个非默认值int count 0;for (int i 0; i < arr.length; i) {for (int j 0; j < arr[i].lengt…

2024年【高压电工】报名考试及高压电工操作证考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年高压电工报名考试为正在备考高压电工操作证的学员准备的理论考试专题&#xff0c;每个月更新的高压电工操作证考试祝您顺利通过高压电工考试。 1、【单选题】 高压电动机发生单相接地故障时,只要接地电流大于()…

无人机飞控算法原理基础研究,多旋翼无人机的飞行控制算法理论详解,无人机飞控软件架构设计

多旋翼无人机的飞行控制算法主要涉及到自动控制器、捷联式惯性导航系统、卡尔曼滤波算法和飞行控制PID算法等部分。 自动控制器是无人机飞行控制的核心部分&#xff0c;它负责接收来自无人机传感器和其他系统的信息&#xff0c;并根据预设的算法和逻辑&#xff0c;对无人机的姿…

MySQL 主键策略导致的效率性能

MySQL官方推荐不要使用uuid或者不连续不重复的雪花id(long形且唯一)&#xff0c;而是推荐连续自增的主键id&#xff0c;官方的推荐是auto_increment 一、准备三张表 分别是user_auto_key&#xff0c;user_uuid&#xff0c;user_random_key&#xff0c;分别表示自动增长的主键…

DolphinScheduler-3.2.0 集群搭建

本篇文章主要记录DolphinScheduler-3.2.0 集群部署流程。 注&#xff1a;参考文档&#xff1a; DolphinScheduler-3.2.0生产集群高可用搭建_dophinscheduler3.2.0 使用说明-CSDN博客文章浏览阅读1.1k次&#xff0c;点赞25次&#xff0c;收藏23次。DolphinScheduler-3.2.0生产…

从0开始图形学(光栅化)

前言 说起图形学&#xff0c;很多人就会提到OpenGL&#xff0c;但其实两者并不是同一个东西。引入了OpenGL加重了学习的难度和成本&#xff0c;使得一些原理并不直观。可能你知道向量&#xff0c;矩阵&#xff0c;纹理&#xff0c;重心坐标等概念&#xff0c;但就是不知道这些概…

正点原子-STM32通用定时器学习笔记(1)

目录 1. 通用定时器简介&#xff08;F1为例&#xff09; 2. 通用定时器框图 ①时钟源 ②控制器 ③时基单元 ④输入捕获 ⑤捕获/比较&#xff08;公共&#xff09; ⑥输出比较 3.时钟源配置 3.1 计数器时钟源寄存器设置方法 3.2 外部时钟模式1 3.3 外部时钟模式2 3…

第二章:三角面片及其填充

本文是《从0开始图形学》笔记的第二章&#xff0c;主要说明模型的一般构成以及如何查找模型的有效范围&#xff0c;涉及三角面片的填充以及向量的叉乘计算。 概念解说 上一节中&#xff0c;我们画出了箱子的顶点和边缘线&#xff0c;箱子还只是一个骨架而已。这一节我们来将箱…

再识C语言 DAY17 【什么是原码、反码和补码】

文章目录 前言本文总结于此文章 一、知识补充二、原码三、反码四&#xff0c;补码 总结如果您发现文章有错误请与我留言&#xff0c;感谢 前言 本文总结于此文章 一、知识补充 通常&#xff0c;1字节包含8位。C语言用字节&#xff08;byte&#xff09;表示储存系统字符集所需…

草莓CDMS原创内容分销系统,微信小说平台系统,附带系统搭建教程,搭建手册

草莓原创内容分销系统&#xff08;草莓CDMS&#xff09;——您的一站式内容分销解决方案 引领内容分销新潮流&#xff0c;草莓原创内容分销系统&#xff08;简称草莓CDMS&#xff09;以强大的技术支持和灵活的业务模式&#xff0c;为原创内容的传播和商业变现提供了前所未有的…

力扣刷题之旅:进阶篇(三)

力扣&#xff08;LeetCode&#xff09;是一个在线编程平台&#xff0c;主要用于帮助程序员提升算法和数据结构方面的能力。以下是一些力扣上的入门题目&#xff0c;以及它们的解题代码。 --点击进入刷题地址 一、动态规划&#xff08;DP&#xff09; 首先&#xff0c;让我们来…

口袋工具箱微信小程序源码

这是一款云开发口袋工具箱微信小程序源码&#xff0c;只有纯前端版本&#xff0c;该版本的口袋工具箱涵盖了13个功能&#xff0c;分别为圣诞帽头像生成、二维码生成、日语50音图、汉字拼音查询、计算器、程序员黄历、娱乐摇骰子、身材计算、所在地天气查询、IP地址查询、手机归…