TR2 - Transformer模型的复现


目录

  • 理论知识
  • 模型结构
    • 结构分解
      • 黑盒
      • 两大模块
      • 块级结构
      • 编码器的组成
      • 解码器的组成
  • 模型实现
    • 多头自注意力块
    • 前馈网络块
    • 位置编码
    • 编码器
    • 解码器
    • 组合模型
    • 最后附上引用部分
  • 模型效果
  • 总结与心得体会


理论知识

Transformer是可以用于Seq2Seq任务的一种模型,和Seq2Seq不冲突。

模型结构

模型整体结构

结构分解

黑盒

以机器翻译任务为例
黑盒

两大模块

在Transformer内部,可以分成Encoder编码器和和Decoder解码器两部分,这也是Seq2Seq的标准结构。
两大模块

块级结构

继续拆解,可以发现模型的由许多的编码器块和解码器块组成并且每个解码器都可以获取到最后一层编码器的输出以及上一层解码器的输出(第一个当然是例外的)。
块组成

编码器的组成

继续拆解,一个编码器是由一个自注意力块和一个前馈网络组成。
编码器的组成

解码器的组成

而解码器,是在编码器的结构中间又插入了一个Encoder-Decoder Attention层。
解码器的组成

模型实现

通过前面自顶向下的拆解,已经基本掌握了模型的总体结构。接下来自底向上的复现Transformer模型。

多头自注意力块

class MultiHeadAttention(nn.Module):"""多头注意力模块"""def __init__(self, dims, n_heads):"""dims: 每个词向量维度n_heads: 注意力头数"""super().__init__()self.dims = dimsself.n_heads = n_heads# 维度必需整除注意力头数assert dims % n_heads == 0# 定义Q矩阵self.w_Q = nn.Linear(dims, dims)# 定义K矩阵self.w_K = nn.Linear(dims, dims)# 定义V矩阵self.w_V = nn.Linear(dims, dims)self.fc = nn.Linear(dims, dims)# 缩放self.scale = torch.sqrt(torch.FloatTensor([dims//n_heads])).to(device)def forward(self, query, key, value, mask=None):batch_size = query.shape[0]# 例如: [32, 1024, 300] 计算10头注意力Q = self.w_Q(query)K = self.w_K(key)V = self.w_V(value)# [32, 1024, 300] -> [32, 1024, 10, 30] 把向量重新分组Q = Q.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)K = K.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)V = V.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)# 1. 计算QK/根dk# [32, 1024, 10, 30] * [32, 1024, 30, 10] -> [32, 1024, 10, 10] 交换最后两维实现乘法attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scaleif mask is not None:# 将需要mask掉的部分设置为很小的值attention = attention.masked_fill(mask==0, -1e10)# 2. softmaxattention = torch.softmax(attention, dim=-1)# 3. 与V相乘# [32, 1024, 10, 10] * [32, 1024, 10, 30] -> [32, 1024, 10, 30]x = torch.matmul(attention, V)# 恢复结构# 0 2 1 3 把 第2,3维交换回去x = x.permute(0, 2, 1, 3).contiguous()# [32, 1024, 10, 30] -> [32, 1024, 300]x = x.view(batch_size, -1, self.n_heads*(self.dims//self.n_heads))# 走一个全连接层x = self.fc(x)return x

前馈网络块

class FeedForward(nn.Module):"""前馈传播"""def __init__(self, d_model, d_ff, dropout=0.1):super().__init__()self.linear1 = nn.Linear(d_model, d_ff)self.dropout = nn.Dropout(dropout)self.linear2 = nn.Linear(d_ff, d_model)def forward(self, x):x = F.relu(self.linear1(x))x = self.dropout(x)x = self.linear2(x)return x

位置编码

class PositionalEncoding(nn.Module):"""位置编码"""def __init__(self, d_model, dropout=0.1, max_len=5000):super().__init__()self.dropout = nn.Dropout(dropout)# 用来存位置编码的向量pe = torch.zeros(max_len, d_model).to(device)# 准备位置信息position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2)* -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)# 注册一个不参数梯度下降的模型参数self.register_buffer('pe', pe)def forward(self, x):x  = x + self.pe[:, :x.size(1)].requires_grad_(False)return self.dropout(x)

编码器

class EncoderLayer(nn.Module):def __init__(self, d_model, n_heads, d_ff, dropout=0.1):super().__init__()self.self_attn = MultiHeadAttention(d_model, n_heads)self.feedforward = FeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask):attn_output = self.self_attn(x, x, x, mask)x = x + self.dropout(attn_output)x = self.norm1(x)ff_output = self.feedforward(x)x = x + self.dropout(ff_output)x = self.norm2(x)return x

解码器

class DecoderLayer(nn.Module):def __init__(self, d_model, n_heads, d_ff, dropout=0.1):super().__init__()self.self_attn = MultiHeadAttention(d_model, n_heads)self.enc_attn = MultiHeadAttention(d_model, n_heads)self.feedforward = FeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, mask, enc_mask):# 自注意力attn_output = self.self_attn(x, x, x, mask)x = x + self.dropout(attn_output)x = self.norm1(x)# 编码器-解码器注意力attn_output = self.enc_attn(x, enc_output, enc_output, enc_mask)x = x + self.dropout(attn_output)x = self.norm2(x)# 前馈网络ff_output = self.feedforward(x)x = x + self.dropout(ff_output)x = self.norm3(x)return x

组合模型

class Transformer(nn.Module):def __init__(self, vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout=0.1):super().__init__()self.embedding = nn.Embedding(vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model)self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_encoder_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_decoder_layers)])self.fc_out = nn.Linear(d_model, vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, src, trg, src_mask, trg_mask):# 词嵌入src = self.embedding(src)src = self.positional_encoding(src)trg = self.embedding(trg)trg = self.positional_encoding(trg)# 编码器for layer in self.encoder_layers:src = layer(src, src_mask)# 解码器for layer in self.decoder_layers:trg = layer(trg, src, trg_mask, src_mask)output = self.fc_out(trg)return output

最后附上引用部分

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

模型效果

编写代码测试模型的复现是否正确(没有跑任务)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')vocab_size = 10000
d_model = 512
n_heads = 8
n_encoder_layers = 6
n_decoder_layers = 6
d_ff = 2048
dropout = 0.1transformer_model = Transformer(vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout).to(device)src = torch.randint(0, vocab_size, (32, 10)).to(device) # 源语言
trg = torch.randint(0, vocab_size, (32, 20)).to(device) # 目标语言src_mask = (src != 0).unsqueeze(1).unsqueeze(2).to(device)
trg_mask = (trg != 0).unsqueeze(1).unsqueeze(2).to(device)output = transformer_model(src, trg, src_mask, trg_mask)
print(output.shape)

打印结果

torch.Size([32, 20, 10000])

说明模型正常运行了

总结与心得体会

我是从CV模型学到Transfromer来的,通过对Transformer模型的复现我发现:

  • 类似于残差的连接在Transformer中也十分常见,还有先缩小再放大的Bottleneck结构。
  • 整个Transformer模型的核心处理对特征的维度没有变化,这一点和CV模型完全不同。
  • Transformer的核心是多头自注意机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2906650.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

STL —— vector(1)

博主首页: 有趣的中国人 专栏首页: C专栏 本篇文章主要讲解vector使用的相关内容 1. vector简介 vector 是 C 标准库中的一个容器类模板,它提供了动态数组的功能,可以方便地管理和操作元素的集合。下面是关于 vector 的一些基本信…

NRF24L01P和SI24R1的区别

NRF24L01无线模块广泛地运用于:无线门禁、无线数据通讯、安防系统、遥控装置、遥感 勘测、智能运动设备、工业传感器;平常我们用到的无线鼠标基本上采用的都是NORDIC的N RF24L01无线模块方案,而且,只需要一个5号电池即可。 几年前…

HarmonyOS实战开发-如何实现一个自定义抽奖圆形转盘

介绍 本篇Codelab是基于画布组件、显式动画,实现的一个自定义抽奖圆形转盘。包含如下功能: 通过画布组件Canvas,画出抽奖圆形转盘。通过显式动画启动抽奖功能。通过自定义弹窗弹出抽中的奖品。 相关概念 Stack组件:堆叠容器&am…

详解TCP的三次握手和四次挥手

文章目录 1. TCP报文的头部结构2. 三次握手的原理与过程三次握手连接建立过程解析 3. 四次挥手的原理与过程四次挥手连接关闭过程的解析 4. 常见面试题 深入理解TCP连接:三次握手和四次挥手 在网络通信中,TCP(传输控制协议)扮演着…

人才推荐 | 材料化学博士,热衷于创新且可扩展的电池技术开发

编辑 / 木子 审核 / 朝阳 伟骅英才 伟骅英才致力于以大数据、区块链、AI人工智能等前沿技术打造开放的人力资本生态,用科技解决职业领域问题,提升行业数字化服务水平,提供创新型的产业与人才一体化服务的人力资源解决方案和示范平台&#x…

java多线程——概述,创建方式及常用方法

前言: 学习到多线程了,整理下笔记,daydayup!!! 多线程 什么是线程 线程(Thread)是一个程序内部的一条执行流程。若程序只有一条执行流程,那这个程序就是单线程的程序。 什么是多线程 多线程是指从软硬件上…

【AIGC】如何在Windows/Linux上部署stable diffusion

文章目录 整体安装步骤windows10安装stable diffusion环境要求安装步骤注意事项参考博客其他事项安装显卡驱动安装cuda卸载cuda安装对应版本pytorch安装git上的python包Q&A linux安装stable diffusion安装anaconda安装cudagit 加速配置虚拟环境挂载oss(optional…

传播力研究期刊投稿发表

《传播力研究》杂志是经国家新闻出版总署批准,黑龙江日报报业集团主管主办,面向全国公开发行的学术刊物。本刊为新闻、传媒、传播学类专业院校师生、文化传播理论研究者和从业人员及爱好者,开展学术交流与研讨,汲取当今业界新鲜的…

RGB,深度图,点云和体素的相互转换记录

目录 1.RGBD2Point 1.2 步骤 2.Point2Voxel-Voxelization 2.1 原理 2.2 代码 3.Voxel2Point 4.Point2RGB 5.Voxel2RGB 1.RGBD2Point input:RGB D 内外惨 output:points cloud def depth2pcd(depth_img):"""深度图转点云数据图…

翻译 《The Old New Thing》 - Why is a registry file called a “hive“?

Why is a registry file called a “hive“?https://devblogs.microsoft.com/oldnewthing/20030808-00/?p42943 为什么注册表文件被称为‘蜂巢’? Raymond Chen 2003年8月8日 分享一个没用的知识: 话说有一位 Windows NT 的开发者十分讨厌蜜蜂。于是&a…

FLV流媒体封装格式

1、FLV 简介 FLV(Flash Video) 是 Adobe 公司推出的一种流媒体格式,由于其封装后的音视频文件体积小、封装简单等特点,非常适合于互联网上使用。目前主流的视频网站基本都支持FLV。采用 FLV 格式封装的文件后缀为.flv。直播场景下拉流比较常见的是 http-…

计算机网络:现代通信的基石

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

小白了解Pinia第2集 · 三大核心状态Getters、Actions以及Plugins 插件

三大核心状态 state 第1集有详细讲解:https://blog.csdn.net/qq_51463650/article/details/137137080?spm1001.2014.3001.5501 getters Getter 完全等同于 Store 状态的 计算值。 它们可以用 defineStore() 中的 getters 属性定义。 他们接收“状态”作为第一个…

Elastic 8.13:Elastic AI 助手中 Amazon Bedrock 的正式发布 (GA) 用于可观测性

作者:来自 Elastic Brian Bergholm 今天,我们很高兴地宣布 Elastic 8.13 的正式发布。 有什么新特性? 8.13 版本的三个最重要的组件包括 Elastic AI 助手中 Amazon Bedrock 支持的正式发布 (general availability - GA),新的向量…

汽车电子行业知识:什么是智能驾驶辅助系统(ADAS)

文章目录 1. 什么是智能驾驶辅助系统(ADAS)1.1 ADAS的功能1.2 ADAS的优势1.3 未来发展趋势 2. ADAS等级2.1. 0级驾驶辅助2.2. 1级驾驶辅助2.3. 2级驾驶辅助2.4. 3级驾驶辅助2.5. 4级和5级驾驶辅助 3. 智能车4. ADAS供应商 1. 什么是智能驾驶辅助系统&…

文章分享:协和文章《病原宏基因组高通量测序性能确认方案》

摘要:宏基因组学利用新一代高通量测序技术,以特定环境下病原体基因组为研究对象,在分析病原体多样性、种群结构、进化关系的基础上,进一步探究病原体的群体功能活性、相互作用及其与环境之间的关系,发掘潜在的生物学意…

STM32之HAL开发——串口配置(CubeMX)

串口引脚初始化(CubeMX) 选择RCC时钟来源 选择时钟频率,配置为最高频率72MHZ 将单片机调试模式打开 SW模式 选择窗口一配置为异步通信模式 点击IO口设置页面,可以看到当前使用的串口一的引脚。如果想使用复用功能,只需…

每天五分钟深度学习:使用神经网络完成人脸的特征点检测

本文重点 我们上一节课程中学习了如何利用神经网络对图片中的对象进行定位,也就是通过输出四个参数值bx、by、bℎ和bw给出图片中对象的边界框。 本节课程我们学习特征点的检测,神经网络可以通过输出图片中对象的特征点的(x,y)坐标来实现对目标特征的识别,我们看几个例子。…

前端发版上线出现白屏问题

目录 路由配置问题资源缓存问题首屏加载过慢 :喂,你的页面白啦! 出现上线白屏的问题有很多,如:配置错误、缓存问题、浏览器兼容问题,根据不同情况去解决。 路由配置问题 问题描述: 在vue开发…

C语言中的联合体和枚举

联合体 联合体的创建 联合体的关键字是union union S {char a;int i; };除了关键字和结构体不一样之外,联合体的创建语法形式和结构体的很相似,如果不熟悉结构体的创建,可以看一下我上一篇的博客关于结构体知识的详解。 联合体的特点 联合…