Transformer的PyTorch实现之若干问题探讨(一)

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。

1.自定义数据中enc_input、dec_input及dec_output的区别

博文中给出了两对德语翻译成英语的例子:

# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [# enc_input   dec_input    dec_output['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]

初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。他们在transformer block中的位置如下:
在这里插入图片描述

在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。

我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。

实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。

2.Transformer的训练流程

此处给出博文中附带的非常简洁的Transformer训练代码:

from torch import optim
from model import *model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)# 训练开始
for epoch in range(1000):for enc_inputs, dec_inputs, dec_outputs in loader:'''enc_inputs: [batch_size, src_len] [2,5]dec_inputs: [batch_size, tgt_len] [2,6]dec_outputs: [batch_size, tgt_len] [2,6]'''enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]loss = criterion(outputs, dec_outputs.view(-1))  # 将dec_outputs展平成一维张量# 更新权重optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')

这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:

ich mochte ein bier

在进行训练时,当我们给出起始符S,接下来应该预测出:

I

那训练时,有了SI后,则应该预测出

want

那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。

那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2775668.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

《PCI Express体系结构导读》随记 —— 第II篇 第4章 PCIe总线概述(10)

接前一篇文章:《PCI Express体系结构导读》随记 —— 第II篇 第4章 PCIe总线概述(9) 4.2 PCIe体系结构的组成部件 PCIe总线作为处理器系统的局部总线,其作用与PCI总线类似,主要目的是为了连接处理器系统中的外部设备&…

Vue源码系列讲解——虚拟DOM篇【二】(Vue中的DOM-Diff)

目录 1. 前言 2. patch 3. 创建节点 4. 删除节点 5. 更新节点 6. 总结 1. 前言 在上一篇文章介绍VNode的时候我们说了,VNode最大的用途就是在数据变化前后生成真实DOM对应的虚拟DOM节点,然后就可以对比新旧两份VNode,找出差异所在&…

MATLAB知识点:使用逻辑值修改或删除矩阵元素

​讲解视频:可以在bilibili搜索《MATLAB教程新手入门篇——数学建模清风主讲》。​ MATLAB教程新手入门篇(数学建模清风主讲,适合零基础同学观看)_哔哩哔哩_bilibili 节选自第3章 3.4.4 逻辑运算 3.4.4.3 使用逻辑值修改或删…

Elasticsearch(四)

是这样的前面的几篇笔记,感觉对我没有形成知识体系,感觉乱糟糟的,只是大概的了解了一些基础知识,仅此而已,而且对于这技术栈的学习也是为了在后面的java开发使用,但是这里的API学的感觉有点乱!然…

熔断机制解析:如何用Hystrix保障微服务的稳定性

微服务与系统的弹性设计 大家好,我是小黑,在讲Hystrix之前,咱们得先聊聊微服务架构。想象一下,你把一个大型应用拆成一堆小应用,每个都负责一部分功能,这就是微服务。这样做的好处是显而易见的,更新快,容错性强,每个服务可以独立部署,挺美的对吧?但是,问题也随之而…

PKI - 借助Nginx 实现Https 服务端单向认证、服务端客户端双向认证

文章目录 Openssl操系统默认的CA证书的公钥位置Nginx Https 自签证书Nginx Https 使用CA签发证书客户端使用自签证书供服务端验证客户端使用 根证书 签发客户端证书 供服务端验证 Openssl https://www.openssl.net.cn/ openssl是一个功能丰富且自包含的开源安全工具箱。 它提…

放假--寒假自学版 day1(补2.5)

fread 函数: 今日练习 C语言面试题5道~ 1. static 有什么用途?(请至少说明两种) 1) 限制变量的作用域 2) 设置变量的存储域 2. 引用与指针有什么区别? 1) 引用必须被初始化,指针不必。 2) 引用初始…

无心剑七绝《龙年大吉》

七绝龙年大吉 龙腾五岳九州圆 年吼佳音万里传 大漠苍鹰华夏梦 吉人天相铸奇缘 2024年2月8日 平水韵一先平韵 这首藏头七绝《龙年大吉》是无心剑为2024年春节所创作的诗作。2024年是农历的甲辰年,即龙年。在中国传统文化中,龙是吉祥的象征,代表…

PSM-Net根据Stereo图像生成depth图像

一、新建文件夹 在KITTI数据集下新建depth_0目录 二、激活anaconda环境 conda activate pt14py37三、修改submission.py文件 3.1 KITTI数据集路径 parser.add_argument(--datapath, default/home/njust/KITTI_DataSet/00/, helpselect model)3.2 深度图像输出路径 save…

【复现】九思OA系统 SQL注入漏洞_43

目录 一.概述 二 .漏洞影响 三.漏洞复现 1. 漏洞一: 四.修复建议: 五. 搜索语法: 六.免责声明 一.概述 九思软件自主研发的iThink协同OA办公自动化系统是面向中高端企业、政府机关和事业单位、等大型企业的协同办公软件,面…

pythn-scipy 查漏补缺

1. 2. 3. 4. 5. 6. 7. 8. 9. 偏度 skewness,峰度 kurtosis

TS学习与实践

文章目录 学习资料TypeScript 介绍TypeScript 是什么?TypeScript 增加了什么?TypeScript 开发环境搭建 基本类型编译选项类声明属性属性修饰符getter 与 setter方法static 静态方法实例方法 构造函数继承 与 super抽象类接口interface 定义接口implement…

Linux操作系统基础(二):Linux操作系统概述

文章目录 Linux操作系统概述 一、Linux起源 二、Linux 的含义 三、Linux发行版 Linux操作系统概述 一、Linux起源 Linux创始人——林纳斯 托瓦兹 Linux 诞生于1991年,作者上大学期间实现的 Linux的特点:开源、免费、拥有最为庞大的源码贡献者 …

nginx+flask+Gunicorn反代理服务拿不到真实IP的解决

背景 本人在宝塔linux环境,要部署flask的简单后端并且用Ngnix反代理,用Gunicorn框架部署。(o(╥﹏╥)o中间磕磕绊绊总算部署上去了,需要了解Gunicorn怎么部署的朋友,评论区留言,我加补一篇介绍)…

VuePress + Travis CI + Github Pages 全自动上线文档

整体思路 1.Github 创建项目,本地创建切换到 docs 分支,通过 VuePress 构建文档项目(写一些文档),上传至 Github。 2.Travis CI 自动 clone 后安装依赖、编译、上传至 Github master 分支。 3.通过 GitHub Pages 功…

[UI5 常用控件] 08.Wizard,NavContainer

文章目录 前言1. Wizard1.1 基本结构1.2 属性1.2.1 Wizard:complete1.2.2 Wizard:finishButtonText1.2.3 Wizard:currentStep1.2.4 Wizard:backgroundDesign1.2.5 Wizard:enableBranching1.2.6 WizardStep:…

【网络攻防实验】【北京航空航天大学】【实验一、入侵检测系统(Intrusion Detection System, IDS)实验】

实验一、入侵检测系统实验 1、 虚拟机准备 本次实验使用1台 Kali Linux 虚拟机和1台 Windows XP 虚拟机,虚拟化平台选择 Oracle VM VirtualBox,如下图所示。 2、 Snort环境搭建 实验前,先确保Kali Linux虚拟机能够访问外网,将网络模式设置为“网络地址转换”: 2.1 安装…

【flink状态管理(三)】StateBackend的整体设计、StateBackend创建说明

文章目录 一. 状态后端概述二. StateBackend的整体设计1. 核心功能2. StateBackend的UML3. 小结 三. StateBackend的加载与初始化1. StateBackend创建概述2. StateBackend创建过程 一. 状态后端概述 StateBackend作为状态存储后端,提供了创建和获取KeyedStateBacke…

真正免费的文件恢复软件easyrecovery2024中文版

easyrecovery数据恢复软件是一款广受好评的数据恢复工具,它能够有效地帮助用户恢复各种类型的文件。无论是照片、视频、音乐还是文档,都能轻松地找回这些重要文件。操作安全、用户可自主操作的数据恢复方案,它支持从各种各样的存储介质恢复删…

CTFshow web(命令执行 41-44)

web41 <?php /* # -*- coding: utf-8 -*- # Author: 羽 # Date: 2020-09-05 20:31:22 # Last Modified by: h1xa # Last Modified time: 2020-09-05 22:40:07 # email: 1341963450qq.com # link: https://ctf.show */ if(isset($_POST[c])){ $c $_POST[c]; if(!p…