MHD、MQA、GQA注意力机制详解

MHD、MQA、GQA注意力机制详解

  • 注意力机制详解及代码
    • 前言:
    • MHA
    • MQA
    • GQA

注意力机制详解及代码

前言:

自回归解码器推理是 Transformer 模型的 一个严重瓶颈,因为在每个解码步骤中加 载解码器权重以及所有注意键和值会产生 内存带宽开销

下图为三种注意力机制的结构图和实验结果

在这里插入图片描述

在这里插入图片描述

MHA

多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的“视角”捕获输入的不同信息。

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, hidden_size)self.v_linear = nn.Linear(hidden_size, hidden_size)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key)value = self.split_head(value)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)## 对注意力输出进行拼接output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x):batch_size = x.size()[0]return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

MQA

多查询注意力(MQA)可能导致质量下降和训练不稳定,并且训练针对质量和推理优化的单独模型可能不可行。此外,虽然一些语言模型已经使用了多查询注意力,如PaLM但许多语言模型没有,包括公开可用的语言模型,如T5和LLaM.

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=1,v=1)。相当于多个query,即多查询。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.head_dim) ###self.v_linear = nn.Linear(hidden_size, self.head_dim) ##### 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, 1)value = self.split_head(value, 1)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, head_num=None):batch_size = x.size()[0]if head_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)

GQA

  • 使用 5% 的原始预训练 计算将现有的多头语言模型检查点训 练到具有 MQA 的模型中
  • 引入分组查询注意力 (GQA),这是多 头语言模型的泛化。查询注意力,它使用中间,多于一个,少于查询头数量的键值头。
  • 经过训练的GQA 实现了接近多头注意力 的质量,并且速度与 MQA 相当。
  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=group_num,v=group_num)。相当于把多头分组了,比如原先有10个头,那就是10个query,分成5组,每组2个query,1个value,1个key。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 分组注意力查询
import torch
from torch import nn
class MutiGroupAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads, group_num):super(MutiGroupAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_headsself.group_num = group_num## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, self.group_num)value = self.split_head(value, self.group_num)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, group_num=None):batch_size,seq_len = x.size()[:2]if group_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3031531.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

AGV混合型电机驱动器|伺服控制器CNS-MI50H系列对电机的要求

混合型电机驱动器 CNS-MI50H系列涵盖CNS-MI50HB-A、CNS-MI50HBN-A、CNS-MI50HDN-A、CNS-MI50HSN-A型号,专为 AGV 舵轮控制需求设计,集成舵轮转向角度控制和驱动电机闭环控制。支持增量式编码器,霍尔传感器, 角度电位计&#xff0c…

UE 解决相同按键的按键事件只会执行一次的问题

在不同蓝图有同样按键的按键事件或者是同一个蓝图但是有很多个实例,默认都只会执行一次事件 因为: 勾选Consume Input:当这个选项被勾选时,意味着你的Actor会“消耗”这个输入事件,阻止它继续传播到其他可能也在监听相…

GEE错误——COPERNICUS/S2_SR_HARMONIZED(Level-2A)数据中不包含QA60波段解决方案(去云解决方案)

问题 我在屏蔽哨兵-2 协调图像集中有云层覆盖的像素时遇到了一个问题。云遮蔽功能是从 GEE 文档中获取的,因此运行正常。它使用的是 "QA60 "波段。 如果不屏蔽云层像素,图像就会出现在地图画布上: 但是,如果遮挡了多云像素,则不会显示图像: 原始代码 var se…

苹果电脑怎么安装crossover 如何在Mac系统中安装CrossOver CrossOver Mac软件安装说明

很多Mac的新用户在使用电脑的过程中,常常会遇到很多应用软件不兼容的情况。加上自己以前一直都是用Windows系统,总觉得Mac系统用得很难上手。 其实,用户可以在Mac上安装CrossOver,它支持用户在Mac上运行Windows软件,例…

TCP协议的确认应答机制

TCP(Transmission Control Protocol)是一种面向连接的、可靠的、基于字节流的传输层协议,它在网络通信中扮演着至关重要的角色。其中,确认应答机制是TCP协议中的一个核心概念,它确保了数据的可靠传输。本文将详细介绍J…

20240511每日运维----聊聊nignx改配置所有的nginx改完unknow

1、改配置所有的nginx改完unknow src/core/nginx.h src/http/ngx_http_header_filter_module.c src/http/ngx_http_special_response.c src/http/v2/ngx_http_v2_filter_module.c 2、make 3、去objs里面把nginx文件替换过去sbin/nginx

Xbox总裁:关闭游戏工作室 确保游戏质量与长期健康发展

易采游戏网5月11日消息,近日微软Xbox宣布关闭旗下四家知名游戏工作室,这一决策在游戏界引起了广泛关注。此举并非简单的资源调整,而是微软为确保旗下游戏质量,以及Xbox平台的长期健康发展所做出的重要战略部署。 被关闭的工作室包…

锐捷EWEB网管系统RCE漏洞

文章目录 免责声明漏洞描述漏洞原理影响版本漏洞复现修复建议 免责声明 该文章只为学习和交流,请不要做违法乱纪的事情,如有与本人无关 漏洞描述 锐捷网管系统是由北京锐捷数据时代科技有限公司开发的新一代基于云的网络管理软件,以"…

我国吻合器市场规模不断扩大 国产化率有所增长

我国吻合器市场规模不断扩大 国产化率有所增长 吻合器是替代手工切除或缝合的一种医疗器械,其工作原理与订书机十分相似,可利用钛钉对组织进行离断或吻合。经过多年发展,吻合器种类逐渐增多,根据手术方式不同,吻合器大…

【每日力扣】437. 路径总和 III 与105. 从前序与中序遍历序列构造二叉树

🔥 个人主页: 黑洞晓威 😀你不必等到非常厉害,才敢开始,你需要开始,才会变的非常厉害 437. 路径总和 III 给定一个二叉树的根节点 root ,和一个整数 targetSum ,求该二叉树里节点值之和等于 ta…

算法专题:位运算

目录 常见位运算总结 位运算相关算法题 1. 只出现一次的数字 2. 只出现一次的数字(|||) 3. 两整数之和 4. 只出现一次的数字(||) 常见位运算总结 在开始刷位运算这个类型的题目前,我想先带着大家学习一下一些常见…

2024年成都市企业技术中心认定申报条件要求、评价标准和时间

一、2024年成都市企业技术中心认定 (一)申报条件 1.在成都市行政区域内注册,具有独立法人资格。 2.已建立企业技术中心并正常运行1年以上。 3.有较强的经济、技术实力和较好的经济效益,在同…

Funkey游戏机新作,基于全志T113的全新版本

不同于配置高端、性能强劲的Windows、安卓掌机,有一部分的爱好者往往对拥有复古外形的开源掌机更加感兴趣。作为开源掌机的热门产品,小巧便携的FunKeys掌机是各位开源爱好者争相复刻的对象。因热爱开源掌机DIY而聚集的“双核掌机开发组”开发者团队&…

【python量化交易】qteasy使用教程05——创建第一个自定义交易策略

创建第一个自定义交易策略 使用qteasy创建自定义交易策略开始前的准备工作本节的目标自定义策略的实现方法使用 qteasy 的 Strategy 策略类三种不同的自定义策略基类定义一个双均线择时交易策略定义策略运行时机定义策略需要的数据自定义交易策略的实现:realize()获…

OpenGL入门第四步:摄像机视角变换与交互

OpenGL入门第一步:创建窗口、重写虚函数-CSDN博客 OpenGL入门第二步:颜色、纹理设置(解析)-CSDN博客 OpenGL入门第三步:矩阵变换、坐标系统-CSDN博客 目录 函数解析 具体代码 函数解析 相机视角变换需要与鼠标键盘进行交互,需要重写鼠标和键盘响应函数。 初始化 …

【Java】获取近六个月的年月

数据库里面存储的字段类型就是varchar&#xff0c;数据格式就是类似2024-12这样的年月格式。 目标&#xff1a; 以当前月份为标准&#xff0c;向前获取近6个月的年月&#xff08;year_month&#xff09;形成列表 // 获取近6个月的年月列表List<String> recentMonths ge…

java项目之相亲网站的设计与实现源码(springboot+mysql+vue)

风定落花生&#xff0c;歌声逐流水&#xff0c;大家好我是风歌&#xff0c;混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的相亲网站的设计与实现。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; 相亲网站的设计与实…

Unable to locate the .NET SDK

问题描述&#xff1a; vs2019 加载项目时&#xff0c;提示如下&#xff1a; Unable to locate the .NET SDK as specified by global.json, please check that the specified version is installed. 项目中没有globan找al.json 文件 先使用&#xff1a; dotnet --list-sdks 命…

论文研读 Disentangled Information Bottleneck

解耦信息瓶颈 摘要&#xff1a; 信息瓶颈方法是一种从源随机变量中提取与预测目标随机变量相关的信息的技术&#xff0c;通常通过优化平衡压缩和预测项的IB拉格朗日乘子f来实现&#xff0c;然而拉格朗日乘子很难优化&#xff0c;需要多次实验来调整拉格朗日乘子的值&#xff0c…

mybatis 跨库查询 mysql

跨库&#xff0c;表关联的查询&#xff0c;实现起来很简单&#xff1a; select a.uid from ucenter.user a , database user_profile b where a.uid b.uid;只要在表的前边加上库名即可。 这个是我项目中xml 中的一个例子&#xff0c;项目采用的是springmvc,持久层框架就是my…