人工智能算法工程师(中级)课程20-模型注意力机制之注意力机制的原理、计算方式与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程20-模型注意力机制之注意力机制的原理、计算方式与代码详解。本文深入探讨了注意力机制在深度学习中的应用与原理,尤其聚焦于序列到序列模型的上下文中。通过直观的解释和详细的代码示例,文章展示了如何使用PyTorch框架实现注意力机制,包括普通注意力和计算,自注意力和计算,多头注意力和计算,通道注意力,空间注意力,混合注意力。从理论到实践,逐步解析了注意力机制如何帮助模型更好地理解输入序列的关键部分,从而提高预测准确性。适合对深度学习和自然语言处理感兴趣的读者,尤其是希望深入了解并实际操作注意力机制的开发人员。

文章目录

  • 一、引言
  • 二、普通注意力与计算
    • 1. 注意力的原理
    • 2. PyTorch代码实现
  • 三、自注意力与计算
    • 1. 自注意力的原理
    • 2. PyTorch代码实现
  • 四、多头注意力与计算
    • 1. 多头注意力的原理
  • 五、通道注意力
    • 1. 通道注意力的原理
    • 2. PyTorch代码实现
  • 六、空间注意力
    • 1. 空间注意力的原理
    • 2. PyTorch代码实现
  • 七、混合注意力
    • 1. 混合注意力的原理
    • 2. PyTorch代码实现
  • 八、总结

一、引言

近年来,注意力机制在深度学习领域取得了显著的成果,尤其是在自然语言处理、计算机视觉等领域。本文将详细介绍注意力机制的不同类型,包括普通注意力、自注意力、多头注意力、通道注意力、空间注意力和混合注意力,并阐述其数学原理,最后提供基于PyTorch的完整可运行代码。

二、普通注意力与计算

1. 注意力的原理

普通注意力机制的核心思想是:为输入序列中的每个元素分配一个权重,权重的大小表示该元素对输出结果的重要性。
假设输入序列为 X = { x 1 , x 2 , … , x n } X = \{x_1, x_2, \ldots, x_n\} X={x1,x2,,xn},注意力权重为 a = { a 1 , a 2 , … , a n } a = \{a_1, a_2, \ldots, a_n\} a={a1,a2,,an},则输出为:
O = ∑ i = 1 n a i x i O = \sum_{i=1}^{n} a_i x_i O=i=1naixi
其中, a i a_i ai 的计算方式为:
a i = exp ⁡ ( e i ) ∑ j = 1 n exp ⁡ ( e j ) a_i = \frac{\exp(e_i)}{\sum_{j=1}^{n} \exp(e_j)} ai=j=1nexp(ej)exp(ei)
e i = score ( x i , h ) e_i = \text{score}(x_i, h) ei=score(xi,h)
score \text{score} score 函数用于计算输入元素 x i x_i xi 与隐藏状态 h h h 的匹配程度。

2. PyTorch代码实现

import torch
import torch.nn as nn
class Attention(nn.Module):def __init__(self):super(Attention, self).__init__()def forward(self, x, h):e = torch.matmul(x, h.unsqueeze(2)).squeeze(2)a = torch.softmax(e, dim=1)o = torch.sum(a.unsqueeze(2) * x, dim=1)return o
x = torch.randn(10, 5)  # 假设输入序列长度为10,特征维度为5
h = torch.randn(5)      # 隐藏状态维度为5
attention = Attention()
output = attention(x, h)
print(output)

三、自注意力与计算

1. 自注意力的原理

自注意力(Self-Attention)机制允许输入序列中的每个元素都与其他元素进行关联,从而更好地捕捉序列内部的依赖关系。
自注意力计算过程如下:
Q = X W Q , K = X W K , V = X W V Q = XW^Q, K = XW^K, V = XW^V Q=XWQ,K=XWK,V=XWV
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk QKT)V
其中, W Q , W K , W V W^Q, W^K, W^V WQ,WK,WV 是可学习的权重矩阵, d k d_k dk K K K 的维度。
在这里插入图片描述

2. PyTorch代码实现

class SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsself.query_linear = nn.Linear(self.head_dim, self.head_dim)self.key_linear = nn.Linear(self.head_dim, self.head_dim)self.value_linear = nn.Linear(self.head_dim, self.head_dim)def forward(self, x):batch_size = x.shape[0]x = x.reshape(batch_size, -1, self.heads, self.head_dim)query = self.query_linear(x)key = self.key_linear(x)value = self.value_linear(x)attention = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)attention = torch.softmax(attention, dim=-1)out = torch.matmul(attention, value)out = out.reshape(batch_size, -1, self.embed_size)return out
x = torch.randn(10, 32, 512)  # 假设输入序列长度为10,特征维度为512,head数为32
self_attention = SelfAttention(embed_size=512, heads=32)
output = self_attention(x)
print(output)

四、多头注意力与计算

1. 多头注意力的原理

多头注意力(Multi-Head Attention)机制是将自注意力机制分解为多个头,每个头关注不同的信息,最后将多个头的输出拼接起来。
MultiHead ( Q , K , V ) = Concat ( head 1 , … , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,,headh)WO
其中, head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV) W i Q , W i K , W i V W_i^Q, W_i^K, W_i^V WiQ,WiK,WiV 是第 i i i 个头的权重矩阵。
在这里插入图片描述

  1. PyTorch代码实现方式
class MultiHeadAttention(nn.Module):def __init__(self, embed_size, heads):super(MultiHeadAttention, self).__init__()self.self_attention = SelfAttention(embed_size, heads)self.linear = nn.Linear(embed_size, embed_size)def forward(self, query, key, value):attention = self.self_attention(query, key, value)out = self.linear(attention)return out
x = torch.randn(10, 32, 512)  # 假设输入序列长度为10,特征维度为512,head数为32
multi_head_attention = MultiHeadAttention(embed_size=512, heads=32)
output = multi_head_attention(x, x, x)
print(output)

五、通道注意力

1. 通道注意力的原理

通道注意力(Channel Attention)机制是在特征图的通道维度上进行注意力加权,强调重要的特征通道。
通道注意力的一般计算方式为:
ChannelAttention ( F ) = σ ( MLP ( AvgPool ( F ) ) + MLP ( MaxPool ( F ) ) ) \text{ChannelAttention}(F) = \sigma(\text{MLP}(\text{AvgPool}(F)) + \text{MLP}(\text{MaxPool}(F))) ChannelAttention(F)=σ(MLP(AvgPool(F))+MLP(MaxPool(F)))
其中, F F F 是输入特征图, AvgPool \text{AvgPool} AvgPool MaxPool \text{MaxPool} MaxPool 分别表示平均池化和最大池化, MLP \text{MLP} MLP 表示多层感知机, σ \sigma σ 表示sigmoid激活函数。
在这里插入图片描述

2. PyTorch代码实现

class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // ratio, bias=False),nn.ReLU(inplace=True),nn.Linear(in_channels // ratio, in_channels, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x).view(x.size(0), -1))max_out = self.fc(self.max_pool(x).view(x.size(0), -1))out = self.sigmoid(avg_out + max_out)return out * x
x = torch.randn(32, 512, 7, 7)  # 假设输入特征图的形状为(32, 512, 7, 7)
channel_attention = ChannelAttention(in_channels=512)
output = channel_attention(x)
print(output)

六、空间注意力

1. 空间注意力的原理

空间注意力(Spatial Attention)机制是在特征图的空间维度上进行注意力加权,强调重要的空间位置。
空间注意力的计算方式为:
SpatialAttention ( F ) = σ ( conv ( Concat ( AvgPool ( F ) , MaxPool ( F ) ) ) ) \text{SpatialAttention}(F) = \sigma(\text{conv}(\text{Concat}(\text{AvgPool}(F), \text{MaxPool}(F)))) SpatialAttention(F)=σ(conv(Concat(AvgPool(F),MaxPool(F))))
其中, F F F 是输入特征图, conv \text{conv} conv 表示卷积操作。
在这里插入图片描述

2. PyTorch代码实现

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size % 2 == 1, "Kernel size must be odd"self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)return self.sigmoid(x) * x
x = torch.randn(32, 512, 7, 7)  # 假设输入特征图的形状为(32, 512, 7, 7)
spatial_attention = SpatialAttention(kernel_size=7)
output = spatial_attention(x)
print(output)

七、混合注意力

1. 混合注意力的原理

混合注意力(Hybrid Attention)机制是将通道注意力与空间注意力结合使用,从而同时捕捉特征通道和空间位置的重要性。
混合注意力的计算方式为:
HybridAttention ( F ) = SpatialAttention ( ChannelAttention ( F ) ) \text{HybridAttention}(F) = \text{SpatialAttention}(\text{ChannelAttention}(F)) HybridAttention(F)=SpatialAttention(ChannelAttention(F))

2. PyTorch代码实现

class HybridAttention(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=7):super(HybridAttention, self).__init__()self.channel_attention = ChannelAttention(in_channels, ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):x = self.channel_attention(x)x = self.spatial_attention(x)return x
x = torch.randn(32, 512, 7, 7)  # 假设输入特征图的形状为(32, 512, 7, 7)
hybrid_attention = HybridAttention(in_channels=512)
output = hybrid_attention(x)
print(output)

八、总结

本文详细介绍了注意力机制的不同类型,包括普通注意力、自注意力、多头注意力、通道注意力、空间注意力和混合注意力,并提供了每种注意力的数学原理和基于PyTorch的代码实现。通过这些示例,我们可以看到注意力机制在深度学习模型中的重要作用,它能够有效地提高模型对关键信息的捕捉能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3267132.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

48 mysql 全局变量修改了时区, 客户端拿到的依然是旧时区

前言 这是一个 我们最近碰到的问题 在我们的一个 服务平台 查询到的时间字段 比 当前时区的当前时间多 8 小时 然后 这个问题 也是挺神奇的, navicate 上面查询到的 时间是在正常的时间 然后 查询环境变量 tz_zone 是 “08:00”, 也没有问题, 但是 客户端这边 拿到的是 当…

【HTML+CSS】HTML超链接:构建网页导航的基石

目录 什么是HTML超链接? 基本语法 示例 链接到另一个网页 链接到同一页面内的不同部分 常用属性 在Web开发的广阔世界中,HTML(HyperText Markup Language)作为网页内容的标准标记语言,扮演着至关重要的角色。而在…

nodejs安装及环境配置轨道交通运维检测系统App-OA人事办公排班故障维修

✌网站介绍:✌10年项目辅导经验、专注于计算机技术领域学生项目实战辅导。 ✌服务范围:Java(SpringBoo/SSM)、Python、PHP、Nodejs、爬虫、数据可视化、小程序、安卓app、大数据等设计与开发。 ✌服务内容:免费功能设计、免费提供开题答辩P…

【SpringCloud】企业认证、分布式事务,分布式锁方案落地-2

目录 高并发缓存三问 - 穿透 缓存穿透 概念 现象举例 解决方案 缓存穿透 - 预热架构 缓存穿透 - 布隆过滤器 布隆过滤器 布隆过滤器基本思想​编辑 了解 高并发缓存三问 - 击穿 缓存击穿 高并发缓存三问 - 雪崩 缓存雪崩 解决方案 总结 为什么要使用数据字典&…

对Linux目录结构的补充

📑打牌 : da pai ge的个人主页 🌤️个人专栏 : da pai ge的博客专栏 ☁️宝剑锋从磨砺出,梅花香自苦寒来 ☁️运维工程师的职责:监…

白鲸开源CEO郭炜荣获「2024中国数智化转型升级先锋人物」称号

2024年7月24日,由数据猿主办,IDC协办,新华社中国经济信息社、上海大数据联盟、上海市数商协会、上海超级计算中心作为支持单位,举办“数智新质力拓未来 2024企业数智化转型升级发展论坛——暨AI大模型趋势论坛”数据猿“年中特别策…

数据结构_study(一)

术语 程序设计数据结构算法 数据结构:相互之间存在一种或多种特定关系的数据元素的集合 数据:输入到计算机中可以操作的对象,数值类型(整型,浮点型),非数值类型(字符,…

算法——二分查找(day10)

目录 69. x 的平方根 题目解析: 算法解析: 代码: 35. 搜索插入位置 题目解析: 算法解析: 代码: 69. x 的平方根 69. x 的平方根 - 力扣(LeetCode) 题目解析: 老…

Linux 安装mysql-client-core-8.0

在Linux上安装mysql-client-core-8.0 安装流程 下面是安装mysql-client-core-8.0的步骤和相应的命令: 步骤1:更新系统软件源 我们首先需要更新系统的软件源,以确保我们能够获取到最新的软件包列表。使用以下命令更新软件源: …

C语言——运算符及表达式

C语言——运算符及表达式 运算符运算符的分类(自增运算符)、--(自减运算符)赋值运算符逗号运算符(顺序求值运算符) 表达式 运算符 运算符的分类 C语言的运算符范围很宽,除了控制语句和输入输出…

go语音进阶 Goroutine

什么是 Goroutine? 在Go语言中 是通过 ‘协程’ 来实现并发, Goroutine 是 Go 语言特有的名词, 区别于进程 Process, 线程Thread, 协程 Coroutine, 因为 Go语言的作者们觉得是有所区别的,所有专门创造做 Go…

python-绝对值排序(赛氪OJ)

[题目描述] 输入 n 个整数,按照绝对值从大到小排序后输出。保证所有整数的绝对值不同。输入格式: 输入数据有多组,每组占一行,每行的第一个数字为 n ,接着是 n 个整数, n0 表示输入数据的结束,不做处理。输…

Ruoyi-WMS本地运行

所需软件 1、JDK:8 安装包:https://www.oracle.com/java/technologies/javase/javase8-archive-downloads.htmlopen in new window 安装文档:https://cloud.tencent.com/developer/article/1698454open in new window 2、Redis 3.0 安装包&a…

[Vulnhub] Acid-Reloaded SQLI+图片数据隐写提取+Pkexec权限提升+Overlayfs权限提升

信息收集 IP AddressOpening Ports192.168.101.158TCP:22,33447 $ nmap -p- 192.168.101.158 --min-rate 1000 -sC -sV Not shown: 65534 closed tcp ports (conn-refused) PORT STATE SERVICE VERSION 22/tcp open ssh OpenSSH 6.7p1 Ubuntu 5ubuntu1.3 (Ubuntu Lin…

C语言玩一下标准输出——颜色、闪烁、加粗、下划线属性

文章目录 C语言玩一下标准输出——颜色、闪烁、加粗、下划线属性转换Tip切换内容介绍显示方式字体色背景色 常用光标控制附示例和运行结果 C语言玩一下标准输出——颜色、闪烁、加粗、下划线属性 标准输出格式其属性可控制,控制由一系列的控制码指定。标准输出函数可…

【微信小程序实战教程】之微信小程序 WXS 语法详解

WXS语法 WXS是微信小程序的一套脚本语言,其特性包括:模块、变量、注释、运算符、语句、数据类型、基础类库等。在本章我们主要介绍WXS语言的特性与基本用法,以及 WXS 与 JavaScript 之间的不同之处。 1 WXS介绍 在微信小程序中&#xff0c…

【Socket编程】了解应用层协议HTTP

HTTP协议 HTTP又叫超文本传输协议。它定义了客户端和服务端之间该如何通信,以交换或者传输超文本(如HTML文档)。HTTP协议是一个无连接、无状态的协议,即每次请求都需要建立新的连接,且服务器不会保存客户端的状态信息…

智谱OpenDay“大有可玩”:30秒将任意文字生成视频

Sora毫无疑问带来AI大模型的全新玩法,大模型可基于任意文字生成视频,这也是这个“大家庭”若干努力(包括Runway的Gen系列、微软的Nuwa、Meta的Emu、谷歌的Phenaki/VideoPoet、CogVideo等)的一个全新高度。 7月26日,这…

FastAPI(七十七)实战开发《在线课程学习系统》接口开发-- 课程编辑和查看评论

源码见:"fastapi_study_road-learning_system_online_courses: fastapi框架实战之--在线课程学习系统" 课程编辑 先来看下课程编辑 1.判断是否登录 2.判断课程是否存在 3.是否有权限(只有自己可以修改自己的课程) 4.名称是否重复…

Docker(十)-Docker运行elasticsearch7.4.2容器实例以及分词器相关的配置

1.下载镜像 1.1存储和检索数据 docker pull elasticsearch:7.4.2 1.2可视化检索数据 docker pull kibana:7.4.22.创建elasticsearch实例 创建本地挂载数据卷配置目录 mkdir -p /software/elasticsearch/config 创建本地挂载数据卷数据目录 mkdir -p /software/elasticse…