Transformer self-attention源码及原理理解

 自注意力计算公式:

  • 在公式(1)中Q(query)是输入一个序列中的一个token,K(key)代表序列中所有token的特征。
  •  QK^{T}可以得到当前token与序列中其他token的相关性。
  • 在论文原文中d_{model}=512,表示每个token用512维特征表示(序列符号的embedding长度)。 d_{k}=d_{model}\div h=64,表示每个头的大小为64。

自注意力机制的pytorch实现:

def attention(query, key, value, mask=None, dropout=None):"Compute 'Scaled Dot Product Attention'"d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) \/ math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attn

多头注意力机制的pytorch实现如下:

class MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):"Take in model size and number of heads."super(MultiHeadedAttention, self).__init__()assert d_model % h == 0# We assume d_v always equals d_kself.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))]#这段代码首先使用zip函数,将self.linears和(query, key, value)这两个列表打包成一个元组列表,其中每个元组包含一个线性层对象和一个输入张量#对遍历的每一个Linear层,对query key value分别计算,结果放在query key value中输出# 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous() \.view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

W_{i}^{Q} \quad W_{i}^{Q} \quad W_{i}^{V}对应Figure2中的三个Linear层的权重,通过训练可得,它们的形状是(需要从代码理解),用来将原始的Q K V投影到下一层做Dot-Production attention计算。

首先Q K V怎么来的?

 和输入序列的token有关

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2870324.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

子组件自定义事件$emit实现新页面弹窗关闭之后父界面刷新

文章目录 需求弹窗关闭之后父界面刷新展示最新数据 实现方案AVUE 大文本默认展开slotVUE 自定义事件实现 父界面刷新那么如何用呢? 思路核心代码1. 事件定义2. 帕斯卡命名组件且在父组件中引入以及注册3. 子组件被引用与父事件监听4.父组件回调函数 5.按钮弹窗事件 需求 弹窗…

【图像分割】使用Otsu 算法及迭代计算最佳全局阈值估计并实现图像分割(代码实现与分析)

本实验要求理解全局阈值分割的概念,并实现文本图像分割。需要大家深入理解Ostu 算法的实现过程及其迭代原理,同时通过学习使用Otsu 算法及其迭代,实践图像分割技术在文本图像处理中的应用。 以下将从实验原理、实验实现、实验结果分析三部分对…

短剧分销怎么赚钱的?保姆级教程助你短剧cps推广赚大钱

短剧分销怎么赚钱的?小白也能月入过万/“蜂小推“保姆级教程,助你短剧分销赚大钱! 相信大家或多或少都在某些群里看到一些“霸道总裁爱上职场小菜鸟...”“这类链接,无利不起早,为什么会有那么多在群里分享这些狗血视…

紧抓需求,把脉市场,方太高端全场景厨电创造厨居新范式

撰稿 | 多客 来源 | 贝多财经 随着“中国制造”向“中国智造”方向转变,厨电不再是单一的工具设施,而是现代化厨居生活的映射,承担着沟通连接人、家庭与社会的桥梁作用。烹饪全场景下智能高效技术、整体美学设计、品类联动能力成为厨电品牌…

【机器学习系列】M3DM工业缺陷检测部署与训练

一.基础资料 1.Git 地址 地址 2.issues issues 3.参考 参考 csdn 二.服务器信息 1.GPU 服务器 GPU 服务器自带 CUDA 安装(前提是需要勾选上)CUDA 需要选择大于 11.3 的版本登录服务器后会自动安装 GPU 驱动 2.CUDA 安装 GPU 服务器自带 CUDA CUDA 版本查看 3.登录信…

从政府工作报告探计算机行业发展——探索计算机行业发展蓝图

目录 前言 一、政策导向与行业发展 (一)政策导向的影响 (二)企业如何把握政策机遇推动创新发展 二、技术创新与产业升级 三、数字经济与数字化转型 四、国际合作与竞争态势 五、行业人才培养与科技创新 (一&a…

【linux】搜索所有目录和子目录下的包含.git的文件并删除

一、linux命令搜索所有目录和子目录下的包含.git的文件 在Linux系统中,要搜索所有目录和子目录下的包含.git的文件,可以使用find命令。find命令允许指定路径、表达式和操作来查找文件。 以下是使用find命令搜索包含.git的文件的方法: 1. 基…

ideaSSM社区二手交易平台C2C模式开发mysql数据库web结构java编程计算机网页源码maven项目

一、源码特点 idea ssm 社区二手交易平台系统是一套完善的完整信息管理系统,结合SSM框架完成本系统SpringMVC spring mybatis ,对理解JSP java编程开发语言有帮助系统采用SSM框架(MVC模式开发),系统具有完整的源代码…

Ubuntu 22.04 Nvidia Audio2Face Error:Failed to build TensorRT engine

背景 1.在Ubuntu22.04上安装Audio2Face后启动,嘴形不会实时同步。控制台显示如【图一】: 【图一】 2.log日志如下: Error: Error during running command: [‘/home/admin/omniverse/libs/deps/321b626abba810c3f8d1dd4d247d2967/exts/omni.audio2fac…

全国农产品价格分析预测可视化系统设计与实现

全国农产品价格分析预测可视化系统设计与实现 【摘要】在当今信息化社会,数据的可视化已成为决策和分析的重要工具。尤其是在农业领域,了解和预测农产品价格趋势对于农民、政府和相关企业都至关重要。为了满足这一需求,设计并实现了全国农产…

C++中的using关键字

1. 类型别名 using关键字可以用来为类型创建一个新的名字,这在代码的可读性和维护性方面非常有帮助。 // 定义类型别名 using IntPtr int*;// 使用 int value 5; IntPtr ptr &value;2. 命名空间别名 如果你正在使用一个非常长的命名空间,可以使…

浅谈HTTP 和 HTTPS (中间人问题)

前言 由于之前的文章已经介绍过了HTTP , 这篇文章介绍 HTTPS 相对于 HTTP 做出的改进 开门见山: HTTPS 是对 HTTP 的加强版 主要是对一些关键信息 进行了加密 一.两种加密方式 1.对称加密 公钥 明文 密文 密文 公钥 明文 2.非对称加密 举个例子就好比 小区邮箱 提供一…

【S5PV210】 | 按键和CPU的中断系统

S5PV210 | 按键和CPU的中断系统 时间:2024年3月17日14:04:27 目录 [TOC] 1.参考 1.项目管理 2.x210bv3s: ARM Cortex-A8 (s5pv210)的开发与学习 硬件版本:(九鼎)X210BV3S 20160513 3.知识星球 | 深度连接…

基于SSM开发网上电子购物商城系统

开发工具:EclipseJdkTomcatMySQL数据库 效果视频: 链接: https://pan.baidu.com/s/1qLB1UKQV42t0TNNJRQZd7Q 提取码: g5xg

C语言例:设 int a=11; 则表达式 a+=a-=a*a 的值

注&#xff1a;软件为VC6.0 代码如下&#xff1a; #include<stdio.h> int main(void) {int a11, b;b (aa-a*a); //a*a121 -->a-121结果为a-110 -->a-110结果为a-220printf("表达式aa-a*a 的值为&#xff1a; %d\n",b);return 0; } //优先级&#x…

sparksql简介

什么是sparksql sparksql是一个用来处理结构话数据的spark模块&#xff0c;它允许开发者便捷地使用sql语句的方式来处理数据&#xff1b;它是用来处理大规模结构化数据的分布式计算引擎&#xff0c;其他分布式计算引擎比较火的还有hive&#xff0c;map-reduce方式。 sparksql…

sqllab第二十七A关通关笔记

知识点&#xff1a; 双引号闭合union select 大小写绕过 Union Select这里不能进行错误注入&#xff0c;无回显 经过测试发现这是一个双引号闭合 构造payload:id1"%09and%091"1 页面成功回显 构造payload:id0"%09uNion%09SElect%091,2,3%09"1 页面成功…

简单高效多语言请求的主流电商平台API数据采集实时接口如何采集数据

电商数据采集API功能概述&#xff1a; 1. 实时采集&#xff1a;1688采集能够自动从阿里巴巴和1688网站抓取商品信息&#xff0c;无需人工手动搜索&#xff0c;节省大量时间。 2. 商品筛选&#xff1a;用户可以根据需求设置采集条件&#xff0c;如价格、销量、信用度等&#x…

JDBC编程(Mysql)

目录 1.什么是jdbc 2.使用 2.1下载mysql数据库驱动 2.2导入项目 2.3编写代码 2.3.1数据源 2.3.2和数据库服务器建立连接 2.3.3构建一个操作数据库的sql语句 2.3.4执行sql 2.3.5释放前面创建的各种资源 2.3.6运行java程序 2.4其他操作 2.4.1修改操作 2.4.2删除操作…

基于Springboot和Redis实现的快递代取系统

1.项目简介 本项目基于springboot框架开发而成&#xff0c;前端采用bootstrap和layer框架开发&#xff0c;系统功能完整&#xff0c;界面简洁大方&#xff0c;比较适合做毕业设计使用。 本项目主要实现了代取快递的信息管理功能&#xff0c;使用角色有三类&#xff1a;一是客…