DEiT中如何处理mask数据的?与MAE的不同

在DeiT里面,是通过mask的方式,将mask+unmasked的patches输出进ViT中,但其实在下游任务输入的patches还是和训练时patches的数量N是一致的(encoder所有的patches)。

而MAE是在encoder中只encoder未被mask的patches

通过什么方式支持的?

  • 在处理文本时,可以根据最长的句子在批次中动态padding或截断长句子
  • 而在处理图像(如使用ViT)时,可以将图像划分为大小相等的patches,数量可以根据图像的大小动态变化。

在训练阶段,部分patches被mask为0,但是处理的所有patches加起来的总长度还是一样的。被mask的位置在模型内部仍然占位,保持了输入序列的“框架”。这样,即使实际参与计算的只是部分元素,模型也能够适应在推理时使用全部元素的情况。

具体的计算步骤如下:

  1. 确定mask哪些patches
  2. 将mask的patches位置设置为0
  3. 这些被mask和未被mask的所有patches一起被输入进attention模块
  4. 将被mask的patches的注意力分数手动设置为“无穷大负数”(-inf)
  5. 这些被mask的patches的softmax值就会变为0,也就意味着这些patches并未参与注意力的计算
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MaskedSelfAttention(nn.Module):def __init__(self, embed_size):super(MaskedSelfAttention, self).__init__()self.query = nn.Linear(embed_size, embed_size)self.key = nn.Linear(embed_size, embed_size)self.value = nn.Linear(embed_size, embed_size)def forward(self, x, mask=None):Q = self.query(x)K = self.key(x)V = self.value(x)# 计算自注意力得分attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(Q.size(-1), dtype=torch.float32))# 将mask值为0的位置在attention_scores中设置为一个非常大的负数attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))# 使得这些位置的softmax结果接近0attention_weights = F.softmax(attention_scores, dim=-1)# 算最终的注意力加权和output = torch.matmul(attention_weights, V)return output# 假设嵌入大小为512
embed_size = 512
# 创建一个mask,假设我们有4个patches,我们想要mask掉第2个和第4个patches
mask = torch.tensor([[1, 0, 1, 0]])
# 扩展mask维度以适应attention_scores的形状(假设批大小为1,序列长度为4),mask需要与attention_scores形状匹配,即(batch_size, 1, 1, seq_length)
mask = mask.unsqueeze(1).unsqueeze(2)# 初始化模型和数据
sa = MaskedSelfAttention(embed_size)
x = torch.randn(1, 4, embed_size)  # 假设有一个批大小为1,序列长度为4的输入output = sa(x, mask)
print(output)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2871735.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

【启动npm run serve 奇怪的报错】

报错如下: INFO Starting development server... utils.js:587Uncaught TypeError [ERR_INVALID_ARG_VALUE]: The argument path must be a string or Uint8Array without null bytes. Received E:\\#\u0000#idea-workspace\\wonderful-search\\wonderful-search-v…

解决:springboot项目访问hdfs文件提示guava版本不兼容

1、问题描述 版本说明:我用的hadoop版本:3.1.3 项目可以正常启动,但是调用访问hdfs的服务时候报错,报错消息如下:com.google.common.base.preconditions.checkArgument(ZL java/lang/String;Ljava/lang/Object:)V 原因分析&#x…

电脑文件误删除如何恢复?分享三个简单数据恢复方法

在日常使用电脑的过程中,文件误删除的情况时有发生。无论是由于操作失误还是病毒感染,丢失的文件都可能对我们的工作和学习造成极大的影响。因此,掌握文件恢复的方法显得尤为重要。下面围绕“电脑文件误删除如何恢复”这一主题,给…

设计模式学习笔记 - 设计原则与思想总结:1.总结回顾面向对象、设计原则、编程规范、重构技巧等知识点

概述 对前面的内容的回顾,温故而知新,包括:面向对象、设计原则、规范与重构三个模块的内容。 1.代码质量评判标准 如何评价代码质量的高低? 代码质量的评价具有很强的主观性,描述代码质量的词汇也有很多&#xff0c…

瑞_Redis_短信登录_Redis代替session的业务流程

文章目录 项目介绍1 短信登录1.1 项目准备1.2 基于Session实现登录流程1.3 Redis代替session的业务流程1.3.1 设计key的结构1.3.2 设计key的具体细节1.3.3 整体访问流程1.3.4 代码实现 🙊 前言:本文章为瑞_系列专栏之《Redis》的实战篇的短信登录章节的R…

uni-app发布 h5 与ASP .NET MVC 后台 发布 到 IIS的同一端口 并配置跨域

iis 安装URL重写 选择对应的后台项目,进行url重写 编辑【模式】部分的内容的重写规则,我这里是h5中请求的前缀是api,大家可以根据自己的前缀进行修改。 编写【操作类型】为重写,并写重写url,按照图中设置即可。 uni…

R语言深度学习-6-模型优化与调试

本教程参考《RDeepLearningEssential》 这是本专栏的最后一篇文章,一路走来,大家应该都可以独立的建立一个自己的神经网络进行特征学习和预测了吧! 6.1 缺失值处理 在我们使用大量数据进行建模的时候,缺失值对模型表现的影响非常…

研究生总结

Note:本博客更多是关于自己的感悟,没有翻阅文件详细查证,如果存在错过,也请提出指正。 1. 半监督回归 相比于半监督分类,半监督回归相对冷门。回归和分类之间有着难以逾越的天谴,预测精度。分类中的类别是可数的&…

Flutter开发进阶之使用工具效率开发

Flutter开发进阶之使用工具效率开发 软件开发团队使用Flutter开发的原因通常是因为Flutter开发性能高、效率高、兼容性好、可拓展性高,作为软件PM来说主要考虑的是范围管理、进度管理、成本管理、资源管理、质量管理、风险管理和沟通管理等,可以看到Flu…

Zookeeper的ZAB协议原理详解

Zookeeper的ZAB协议原理详解 如何保证数据一致性。 Paxos, 吸收了主从。 zk 数据模型Watch机制 zab zookeeper原子广播协议。 ZAB概念 ZooKeeper是通过Zab协议来保证分布式事务的最终一致性。 Zab(ZooKeeper Atomic Broadcast,.ZooKeeper原子广播协议)支持…

Java SE 认识异常 (Java SE完结篇)

1. 异常的概念与体系结构 1.1 异常的概念 在我们的生活中,一个人如果表情痛苦,我们可能会问: 你是生病了吗? 需要我陪你去看医生吗? 程序也和人是一样的,均会发生一些"生病"的行为,比如: 数据格式不对, 数组越界,网络中断等, 我们把这种程序出现的"生病&qu…

Ps:文字工具

工具箱里的文字工具组中包含了四种工具: 横排文字工具 Horizontal Type Tool 直排文字工具 Vertical Type Tool 横排文字蒙版工具 Horizontal Type Mask Tool 直排文字蒙版工具 Vertical Type Mask Tool 快捷键:T 横排文字蒙版工具和直排文字蒙版工具…

iOS--第二章block

第二章block blocks 概要Blocks模式Blocks语法Blocks类型变量截获自动变量值_block 说明符 Blocks的实现Block的实质截获自动变量值_block说明符Block存储域 blocks 概要 Blocks是c语言的扩展,block是一个带有自动变量值的匿名函数,它也是一个数据类型&…

Lua面向对象

封装 实现了New方法,相当于创建了一个表obj,并设置元表可以通过obj表去访问__index指向表中的数据 继承 通过大G表传入字符串创建表,

安卓国产百度网盘与国外云盘软件onedrive对比

我更愿意使用国外软件公司的产品,而不是使用国内百度等制作的流氓软件。使用这些国产软件让我不放心,他们占用我的设备大量空间,在我的设备上推送运行各种无用的垃圾功能。瞒着我,做一些我不知道的事情。 百度网盘安装包大小&…

golang常用库之-golang常用库之-ladon包 | 基于策略的访问控制

文章目录 golang常用库之-ladon包 | 基于策略的访问控制概念使用策略 条件 Conditions自定义conditionLadon Condition使用示例 持久化访问控制(Warden) 结合 Gin 开发一个简易 ACL 接口参考 golang常用库之-ladon包 | 基于策略的访问控制 https://github.com/ory/ladon Lado…

Unity中UGUI中的PSD导入工具的原理和作用

先说一下PSD导入工具的作用,比如在和美术同事合作开发一个背包UI业务系统时,美术做好效果图后,程序在UGUI中制作好界面,美术说这个图差了2像素,那个图位置不对差了1像素,另外一个图大小不对等等一系列零碎的…

从零自制docker-4-【PID Namespace MOUNT Namespace】

文章目录 PID namespace代码mountnamespace通俗理解代码 PID namespace 每个命名空间都有独立的PID空间,即每个命名空间的进程都由一开始分配。 新建立的进程内部进程ID为1 代码 package main import ("log""os/exec""os""sy…

MySQL_数据库图形化界面软件_00000_00001

目录 NavicatSQLyogDBeaverMySQL Workbench可能出现的问题 Navicat 官网地址: 英文:https://www.navicat.com 中文:https://www.navicat.com.cn SQLyog 官网地址: 英文:https://webyog.com DBeaver 官网地址&…

Spring MVC文件上传配置

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl 文件上传 Spring MVC文件上传基于Servlet 3.0实现;示例代码如下: Overrideprotected void customizeRegistration(ServletRegistration.Dynamic reg…