深度学习论文: Attention is All You Need及其PyTorch实现

深度学习论文: Attention is All You Need及其PyTorch实现
Attention is All You Need
PDF:https://arxiv.org/abs/1706.03762.pdf
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

大多数先进的神经序列转换模型采用编码器-解码器结构,其中编码器将输入符号序列转换为连续表示,解码器则基于这些表示逐个生成输出符号序列。在每个步骤中,模型采用自回归方式,将先前生成的符号作为额外输入来生成下一个符号。
在这里插入图片描述

1 Encoder and Decoder Stacks

1-1 Encoder 编码器

编码器采用N=6个结构相同的层堆叠而成。每一层包含两个子层,第一个是多头自注意力机制,第二个则是简单的位置全连接前馈网络。为提升性能,在每个子层之间引入了残差连接,并实施了层归一化。具体来说,每个子层的输出经过LayerNorm(x + Sublayer(x))计算得出,其中Sublayer(x)代表子层自身的功能实现。为了确保残差连接的顺畅进行,编码器中的所有子层以及嵌入层均生成维度为dmodel=512的输出。

def clones(module, N):"Produce N identical layers."return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])class LayerNorm(nn.Module):"Construct a layernorm module (See citation for details)."def __init__(self, features, eps=1e-6):super(LayerNorm, self).__init__()self.a_2 = nn.Parameter(torch.ones(features))self.b_2 = nn.Parameter(torch.zeros(features))self.eps = epsdef forward(self, x):mean = x.mean(-1, keepdim=True)std = x.std(-1, keepdim=True)return self.a_2 * (x - mean) / (std + self.eps) + self.b_2class Encoder(nn.Module):"Core encoder is a stack of N layers"def __init__(self, layer, N):super(Encoder, self).__init__()self.layers = clones(layer, N)self.norm = LayerNorm(layer.size)def forward(self, x, mask):"Pass the input (and mask) through each layer in turn."for layer in self.layers:x = layer(x, mask)return self.norm(x)class SublayerConnection(nn.Module):"""A residual connection followed by a layer norm.Note for code simplicity the norm is first as opposed to last."""def __init__(self, size, dropout):super(SublayerConnection, self).__init__()self.norm = LayerNorm(size)self.dropout = nn.Dropout(dropout)def forward(self, x, sublayer):"Apply residual connection to any sublayer with the same size."return x + self.dropout(sublayer(self.norm(x)))class EncoderLayer(nn.Module):"Encoder is made up of self-attn and feed forward (defined below)"def __init__(self, size, self_attn, feed_forward, dropout):super(EncoderLayer, self).__init__()self.self_attn = self_attnself.feed_forward = feed_forwardself.sublayer = clones(SublayerConnection(size, dropout), 2)self.size = sizedef forward(self, x, mask):"Follow Figure 1 (left) for connections."x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))return self.sublayer[1](x, self.feed_forward)

1-2 Decoder解码器

解码器同样由N=6个结构相同的层堆叠而成。除了编码器层中的两个子层,解码器还额外引入了一个第三子层,专门用于对编码器堆叠的输出执行多头注意力机制。与编码器设计相似,解码器中的每个子层也采用残差连接与层归一化。此外,还对解码器堆叠中的自注意力子层进行了优化,确保在生成输出时,位置i的预测仅依赖于位置小于i的已知输出,这通过掩码和输出嵌入偏移一个位置的方式实现。

class Decoder(nn.Module):"Generic N layer decoder with masking."def __init__(self, layer, N):super(Decoder, self).__init__()self.layers = clones(layer, N)self.norm = LayerNorm(layer.size)def forward(self, x, memory, src_mask, tgt_mask):for layer in self.layers:x = layer(x, memory, src_mask, tgt_mask)return self.norm(x)class DecoderLayer(nn.Module):"Decoder is made of self-attn, src-attn, and feed forward (defined below)"def __init__(self, size, self_attn, src_attn, feed_forward, dropout):super(DecoderLayer, self).__init__()self.size = sizeself.self_attn = self_attnself.src_attn = src_attnself.feed_forward = feed_forwardself.sublayer = clones(SublayerConnection(size, dropout), 3)def forward(self, x, memory, src_mask, tgt_mask):"Follow Figure 1 (right) for connections."m = memoryx = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))return self.sublayer[2](x, self.feed_forward)

2 Attention

在这里插入图片描述
Transformer巧妙地运用了三种不同的多头注意力机制:

  • 在“编码器-解码器注意力”层中,查询源于前一解码器层,而键和值则汲取自编码器的输出。这种设计使得解码器中的任意位置都能关注输入序列的每一个细节,完美模拟了序列到序列模型中典型的编码器-解码器注意力机制。

  • 编码器内部则嵌入了自注意力层。在这个自注意力层中,键、值和查询均来自编码器前一层的输出,确保编码器中的每个位置都能全面捕捉到前一层中的信息。

  • 解码器同样采用了自注意力层,允许解码器中的每个位置关注到包括该位置在内的解码器内部所有位置的信息。为了确保自回归属性的保持,我们精心设计了一个机制:在缩放点积注意力内部,将对应非法连接的softmax输入值设为−∞,从而有效屏蔽了向左的信息流动。

2-1 Scaled Dot-Product Attention

Scaled Dot-Product Attention 缩放点积注意力,输入包括维度为 d k d_{k} dk的查询和键,以及维度为 d v d_{v} dv的值。计算查询与所有键的点积,每个都除以 d k \sqrt{d_{k} } dk ,然后应用softmax函数以获得值的权重。
在这里插入图片描述

def attention(query, key, value, mask=None, dropout=None):"Compute 'Scaled Dot Product Attention'"d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) \/ math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attn

2-2 Multi-Head Attention

与其仅使用具有dmodel维度的键、值和查询来执行单一的注意力函数,将查询、键和值分别通过不同的、学习得到的线性投影,各自线性投影 h h h次,分别映射到 d k d_{k} dk d k d_{k} dk d v d_{v} dv维度后效果更好。随后在这些投影后的查询、键和值上并行地执行注意力函数,从而得到 d v d_{v} dv维度的输出值。最后,我们将这些输出值进行拼接,并再次通过投影得到最终的值。这种方法能够更充分地利用输入信息,提升模型的性能。
在这里插入图片描述
在这项研究中,我们采用了h=8个并行的注意力层,也称之为注意力头。对于每一个注意力头,我们都设定其维度为 d k d_{k} dk = d v d_{v} dv = d m o d e l / h d_{model} / h dmodel/h=64。由于每个注意力头的维度有所降低,因此其整体计算成本与使用完整维度的单头注意力相近。

class MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):"Take in model size and number of heads."super(MultiHeadedAttention, self).__init__()assert d_model % h == 0# We assume d_v always equals d_kself.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))]# 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous() \.view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

3 Position-wise Feed-Forward Networks

除了注意力子层,编码器和解码器的每一层还包括一个全连接前馈网络,它独立且相同地应用于每个位置,包含两个线性变换和一个ReLU激活函数。
在这里插入图片描述
这里的输入和输出的维度是 d m o d e l = 512 d_{model}=512 dmodel=512,而内层的维度是 d f f = 2048 d_{ff}=2048 dff=2048

class PositionwiseFeedForward(nn.Module):"Implements FFN equation."def __init__(self, d_model, d_ff, dropout=0.1):super(PositionwiseFeedForward, self).__init__()self.w_1 = nn.Linear(d_model, d_ff)self.w_2 = nn.Linear(d_ff, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.w_2(self.dropout(F.relu(self.w_1(x))))

4 Embeddings and Softmax

与其他序列转换模型相似,这里使用学习嵌入将输入标记和输出标记转换为 d m o d e l d_{model} dmodel维向量。解码器输出通过线性变换和softmax函数转换为预测概率。在我们的模型中,嵌入层和softmax前的线性变换共享权重矩阵,并在嵌入层中将权重乘以 d m o d e l \sqrt{d_{model}} dmodel

class Embeddings(nn.Module):def __init__(self, d_model, vocab):super(Embeddings, self).__init__()self.lut = nn.Embedding(vocab, d_model)self.d_model = d_modeldef forward(self, x):return self.lut(x) * math.sqrt(self.d_model)

5 Positional Encoding

在编码器和解码器堆栈底部的输入嵌入中添加“位置编码”。位置编码的维度 d m o d e l d_{model} dmodel与嵌入的维度相同,以便可以将两者相加。位置编码有很多选择,包括学习和固定的。

在这里,选用不同频率的正弦和余弦函数:
在这里插入图片描述
其中pos是位置,i是维度。也就是说,位置编码的每个维度都对应于一个正弦波。波长从2π到10000·2π形成几何级数。选择这个函数是因为它可以使模型容易学习通过相对位置进行关注,因为对于任何固定的偏移量k, P E p o s + k PE_{pos+k} PEpos+k都可以表示为 P E p o s PE_{pos} PEpos的线性函数。

此外,在编码器和解码器堆栈中都对嵌入和位置编码的和使用了Dropout。对于基础模型,使用的Dropout ratio 为 P d r o p P_{drop} Pdrop=0.1。

class PositionalEncoding(nn.Module):"Implement the PE function."def __init__(self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# Compute the positional encodings once in log space.pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) *-(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)return self.dropout(x)

参考资料:
1 Attention is All You Need
2 The Illustrated Transformer
3 The Annotated Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2904858.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Codeforces Round 937 (Div. 4) A - F 题解

A. Stair, Peak, or Neither? 题解&#xff1a;直接比较输出即可。 代码&#xff1a; #include<bits/stdc.h> using namespace std ; typedef long long ll ; const int maxn 2e5 7 ; const int mod 1e9 7 ; inline ll read() {ll x 0, f 1 ;char c getchar()…

地方美食分享网站的设计与实现|Springboot+Vue.js+ Mysql+Java+ B/S结构(可运行源码+数据库+设计文档)

本项目包含可运行源码数据库LW&#xff0c;文末可获取本项目的所有资料。 推荐阅读100套最新项目持续更新中..... 2024年计算机毕业论文&#xff08;设计&#xff09;学生选题参考合集推荐收藏&#xff08;包含Springboot、jsp、ssmvue等技术项目合集&#xff09; 目录 1. …

Redis、Mysql双写情况下,如何保证数据一致

Redis、Mysql双写情况下&#xff0c;如何保证数据一致 场景谈谈数据一致性三个经典的缓存模式Cache-Aside Pattern读流程写流程 Read-Through/Write-Through&#xff08;读写穿透&#xff09;Write behind &#xff08;异步缓存写入&#xff09; 操作缓存的时候&#xff0c;删除…

MySQL count(*/column)查询优化

count()是SQL中一个常用的聚合函数&#xff0c;其被用来统计记录的总数&#xff0c;下面通过几个示例来说明此类查询的注意事项及应用技巧。 文章目录 一、count()的含义二、count()的应用技巧2.1 同时统计多列2.2 利用执行计划 一、count()的含义 count()用于统计符合条件的记…

Unity TMP 使用教程

文章目录 1 导入资源包2 字体制作3 表情包制作4 TMP 控件4.1 属性4.2 富文本标签 1 导入资源包 “Window -> TextMeshPro -> Import TMP Essential Resources”&#xff0c;导入完成后会创建一个名为"TextMehs Pro"的文件夹&#xff0c;这里面包含所需要的资源…

使用pytorch构建一个初级的无监督的GAN网络模型

在这个系列中将系统的构建GAN及其相关的一些变种模型&#xff0c;来了解GAN的基本原理。本片为此系列的第一篇&#xff0c;实现起来很简单&#xff0c;所以不要期待有很好的效果出来。 第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN&#xff0c;使用MINIST数据…

就业班 第二阶段 2401--3.27 day8 shell之循环控制

七、shell编程-循环结构 shell循环-for语句 for i in {取值范围} # for 关键字 i 变量名 in 关键字 取值范围格式 1 2 3 4 5 do # do 循环体的开始循环体 done # done 循环体的结束 #!/usr/bin/env bash # # Author: # Date: 2019/…

kubernetes K8s的监控系统Prometheus升级Grafana,来一个酷炫的Node监控界面(二)

上一篇文章《kubernetes K8s的监控系统Prometheus安装使用(一)》中使用的监控界面总感觉监控的节点数据太少&#xff0c;不能快算精准的判断出数据节点运行的状况。 今天我找一款非常酷炫的多维度数据监控界面&#xff0c;能够非常有把握的了解到各节点的数据&#xff0c;以及运…

HarmonyOS 应用开发之显式Want与隐式Want匹配规则

在启动目标应用组件时&#xff0c;会通过显式 Want 或者隐式 Want 进行目标应用组件的匹配&#xff0c;这里说的匹配规则就是调用方传入的 want 参数中设置的参数如何与目标应用组件声明的配置文件进行匹配。 显式Want匹配原理 显式 Want 匹配原理如下表所示。 名称类型匹配…

【leetcode】环形链表的约瑟夫问题

大家好&#xff0c;我是苏貝&#xff0c;本篇博客带大家刷题&#xff0c;如果你觉得我写的还不错的话&#xff0c;可以给我一个赞&#x1f44d;吗&#xff0c;感谢❤️ 点击查看题目 首先我们要明确一点&#xff0c;题目要求我们要用环形链表&#xff0c;所以用数组等是不被允…

某某消消乐增加步数漏洞分析

一、漏洞简介 1&#xff09; 漏洞所属游戏名及基本介绍&#xff1a;某某消消乐&#xff0c;三消游戏&#xff0c;类似爱消除。 2&#xff09; 漏洞对应游戏版本及平台&#xff1a;某某消消乐Android 1.22.22。 3&#xff09; 漏洞功能&#xff1a;增加游戏步数。 4&#xf…

Spark-Scala语言实战(6)

在之前的文章中&#xff0c;我们学习了如何在scala中定义与使用类和对象&#xff0c;并做了几道例题。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&#xff0c;请留下你宝贵的点赞&#xff0c;谢谢。 Spark-S…

智能设备配网保姆级教程

设备配网 简单来说&#xff0c;配网就是将物联网&#xff08;IoT&#xff09;设备连接并注册到云端&#xff0c;使其拥有与云端远程通信的能力。配网后&#xff0c;智能设备才能被手机应用或者项目管理后台控制&#xff0c;依托于智能场景创造价值。本文介绍了配网的相关知识&…

Linux环境安装Redis

Linux环境安装Redis 一&#xff0c;软件安装准备 服务器连接软件 Redis数据库连接软件 这是Windows软件&#xff0c;用于连接Linux服务器使用。推荐使用。 二&#xff0c;下载Redis 下载地址&#xff1a;Index of /releases/ 截止编稿Redis版本已经到7.2.4了&#xff0c;如果…

如何使用Windows电脑部署Lychee私有图床网站并实现无公网IP远程管理本地图片

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法|MySQL| ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-MSVdVLkQMnY9Y2HW {font-family:"trebuchet ms",verdana,arial,sans-serif;f…

什么是RISC-V?开源 ISA 如何重塑未来的处理器设计

RISC-V代表了处理器架构的范式转变&#xff0c;特点是其开源模型简化了设计理念并促进了全球community-driven的开发。RISC-V导致了处理器技术发展前进方式的重大转变&#xff0c;提供了一个不受传统复杂性阻碍的全新视角。 RISC-V起源于加州大学伯克利分校的学术起点&#xff…

腾讯云服务器多少钱一年?2024年最新价格整理

2024年腾讯云4核8G服务器租用优惠价格&#xff1a;轻量应用服务器4核8G12M带宽646元15个月&#xff0c;CVM云服务器S5实例优惠价格1437.24元买一年送3个月&#xff0c;腾讯云4核8G服务器活动页面 txybk.com/go/txy 活动链接打开如下图&#xff1a; 腾讯云4核8G服务器优惠价格 轻…

设计模式 - 简单工厂模式

文章目录 前言 大家好,今天给大家介绍一下23种常见设计模式中的一种 - 工厂模式 1 . 问题引入 请用C、Java、C#或 VB.NET任意一种面向对象语言实现一个计算器控制台程序&#xff0c;要求输入两个数和运算符 号&#xff0c;得到结果。 下面的代码实现默认认为两个操作数为Inte…

阿里云CentOS7安装Hadoop3伪分布式

ECS准备 开通阿里云ECS 略 控制台设置密码 连接ECS 远程连接工具连接阿里云ECS实例&#xff0c;这里远程连接工具使用xshell 根据提示接受密钥 根据提示写用户名和密码 用户名&#xff1a;root 密码&#xff1a;在控制台设置的密码 修改主机名 将主机名从localhost改为需要…

excel中批量插入分页符

excel中批量插入分页符&#xff0c;实现按班级打印学生名单。 1、把学生按照学号、班级排序好。 2、选择班级一列&#xff0c;点击数据-分类汇总。汇总方式选择计数&#xff0c;最后三个全部勾选。汇总结果一定要显示在数据的下发&#xff0c;如果显示在上方&#xff0c;后期…