ViT(论文解读):An Image is worth 16*16 words

研究问题

虽然transformer已经成为NLP领域的标准(BERT、GPT3、T5),但是在CV领域很有限。在CV中,自注意力要么和CNN一起用,要么替换CNN中某个组件后保持整体结构不变。本文证明了对CNN的这种依赖并不必要,在图像分类中,纯Vision Transformer直接作用于一系列图像块也可以取得不错的成果。尤其是当在大规模数据集上进行预训练再迁移到中小型数据集上效果类似于最好的CNN。Transformer的另外一个好处是需要更少的资源就能达到很好的效果。(此处资源少是和更耗卡的模型对比,这里指2500天TPUv3训练)

标题中提到一张图片等价于许多16*16的单词:本文方法将一张图片分割成许多patch,每个patch为16*16大小,所以这个图片就等价于许多16*16的patch的组合体。

摘要

虽然transformer已经成为NLP领域的标准(BERT、GPT3、T5),但是在CV领域很有限。在CV中,自注意力或者和CNN一起用,或者替换CNN中某个组件后保持整体结构不变。本文证明了对CNN的这种依赖并不必要,在图像分类中,纯Vision Transformer直接作用于一系列图像块也可以取得不错的成果。尤其是当在大规模数据集上进行预训练再迁移到中小型数据集上效果类似于最好的CNN。Transformer的另外一个好处是需要更少的资源就能达到很好的效果。(此处资源少是和更耗卡的模型对比,这里指2500天TPUv3训练)

一.Introduction

第一段:Transformer在NLP领域的应用。基于self-attention的模型架构,在NLP成为了必选架构。当前最主流的方式就是在大规模的数据集上进行预训练,再放到特定领域小规模的数据集上进行微调。鉴于transformer的高效性和可扩展性,可以训练1000亿参数,且并未出现性能饱和的现象。
第二段:将自注意力机制应用到计算机视觉相关工作。【将Transformer应用到CV中的困难:(Transformer在做注意力机制的时候是两两相互的,所以计算复杂度是序列长度的平方倍。)想要把Transformer应用到CV,首先就是需要把2d图片转换为一个集合(或序列),最简单的方法就是将每个像素点当成元素,这样导致计算复杂度非常高。所以在视觉领域,CNN仍然占据主导地位,比如RestNet和AlexNet。】Transformer应用到CV的相关工作,有两种方式,一种是CNN和自注意力机制一起混合使用,另一种是全部使用自注意力机制。两种方法都致力于降低序列长度。后提出孤立自注意力和轴自注意力,这些方法在理论上都很高效,但是在处理器上进行计算有难度。
第三段:本文受到了Transformer在NLP领域中可扩展的特性所启发,直接应用一个标准的Transformer作用于图片,尽量做少的修改。最后突出是在有监督的数据集上进行训练,突出原因是因为在NLP领域中基本都是使用无监督数据集进行训练,但是在CV中基本使用有监督的数据集。【Vision Transformer解决序列长度问题:将一张图片打成很多patch,每个Patch是16*16。假如图片大小为224*224,则sequence length(序列长度)就是N=224*224=50176,但是当使用patch时,一个patch就相当于一个元素,有效的长和宽就变成了224/16=14,最后的序列长度就是N=14*14=196。将每个patch当作一个元素,通过一个全连接层就会得到一个linear Embedding,再当作输入传给Transformer。这样一张图片就变成了很多图片块(类似于NLP中的单词),这就是题目中提到的一张图片就是多个16*16的单词的意思。】
这篇文章的主要目的就是Transformer在CV领域中扩展的有多好,就是在于超大数据集和超大模型的两方加持。
第四、五段:ViT和CNN效果比较:在中型的数据集上(ImageNet),如果不加较强的约束,则ViT的效果要比CNN要差一些。作者解释:Transformer跟CNN相比少了一些归纳偏置(一种先验知识或者提前做好的假设),一旦神经网络有了归纳偏置之后,就拥有了很多先验信息,所以只需要较少的数据就可以得到一个不错的模型。但是Transformer没有,就需要从数据中学习。为了验证解释,作者在更大的数据集上进行预训练,从而得到比CNN网络相近或更好的结果。
总结引言:第一段说明transformer在NLP领域表现很好,越大的数据集和模型就会有更好的表现,且没有饱和的现象,就会面临将transformer放入视觉中。第二段讲述前人工作(要么将自记忆力和CNN结合起来,要么用自注意力来取代CNN,从未直接将transformer直接放到视觉中),以及自己工作和前人工作的区别。第三段开始将本文将transformer直接应用到视觉中来,仅对图片做了预处理(把图片打成patch),完全将一个视觉问题转换为NLP问题。最后两段介绍结果,只要有足够的数据做预训练,vision transformer就可以取得很好的效果。

二.Related Work

第一段:Transformer在NLP领域的应用。一般来说,基于Transformer的大规模模型都是在大型语料库上进行预训练,再放到目标训练上进行微调。重点代表为BERT(类似于完形填空,把句子中一些词划掉,预测词)和GPT(language model,已有一个句子去预测一个词),这都是自监督的训练方式。
第二段:自注意力在视觉中的应用。(想要将transformer应用到CV中,第一个问题就是将图片这个二维数据转换为一个1d序列。)在视觉领域中想要使用自注意力,最简单的方式就是就是将每一个像素点当作一个元素,再两两做自注意力,考虑到计算复杂度的问题,就需要Transformer做近似。
(1)使用local neighbourhood(小窗口)来做自注意力,而不使用整张图片,就会使得序列大大减少,从而计算复杂度大大降低
(2)使用sparse transformer,只对稀疏的点做自注意力,成为全局注意力的近似
(3)将自注意力应用到大小不同的block上,在极端情况下使用轴注意力(先在横轴上做自注意力,再从纵轴上做注意力),大大降低了序列的长度。
这三种特制的注意力算法在计算机视觉上都有不错的效果,使得在CPU或GPU跑得快成为可能,但是需要复杂的工程去加速算子。
第三段:提出与本文最相似的一篇论文,区别之处在于ViT使用更大的patch和更大的数据集。
第四段:将CNN和自注意力结合的工作很多,其中包括目标检测、图像分类、视频处理和多模态等等。
第五段:Image GPT和本文也很相似。GPT是用于NLP中的生成性模型,image GPT也是生成性模型,与本文相似之处在于都使用了transformer。

三.Method

在模型设计上尽可能按照原始的Transformer来实现,这样做的好处是可以直接使用在NLP中比较高效的方法,可以直接使用。

A. 整体架构和前向传播

简单来说,模型由三个模块组成,分别为Embedding层(线性投射层,包含将图片打散为patch且转换成向量)、Transformer Encoder(使用标准的block)和MLP Head(用于最终分类的结构)。
前向传播过程:
(1)Patch Embedding:一张图片被分成n个patch,再把patch变成序列放入线性投射层得到Patch Embedding
(2)Position embedding:self-attention没有考虑输入的位置信息,不能对序列进行建模,但是图片的patch需要有顺序,引入Position Embedding实现对patch排序。
(3)Class token:相当于transformer中的cls,作为patch的全局输出。同时也有Position,位置信息永远是0。
(4)所有的token都在和其他的token做交互,class token的输出当做整个图片的特征,经过MLP Head(相当于一个分类头)得到分类结果。最后用交叉熵函数进行模型的训练。

B.图片预处理

标准的transformer模块输入要求是向量(或者token)序列,即二维矩阵[num_token, token_dim]。对于图像数据而言,其数据为[H, W, C]格式的三维矩阵,所以需要先通过一个Embedding层来对数据做变换。下面以ViT-B为例解释前向传播过程。
输入图片X的维度是224*224*3(3为RGB维度),每个patch大小为16*16,共有(224*224)/(16*16)=196个patch。每个patch的维度就是16*16*3=768。变成二维矩阵X[196(patch的个数),768(patch的维度)]。
接着使用线性映射E将每个patch映射成一维向量,这个线性映射就是全连接层,维度为768*768(前768为16*16*3得来,后768为文章中D,当模型变复杂,D可以变化)。
X*E = [196,768]*[768,768]=[196,768],从而得到patch embedding,表示有196个token,每个token的维度是768,此时已经完成将一个CV问题转化为NLP问题。
额外的cls token维度也是768,这样可以方便和后面图像的信息直接进行拼接。所以最后整体进入Transformer的序列的长度是197(196+1)*768。
Position Embedding是可以学习的,每一个向量代表一个位置信息(向量的维度是768),将这些位置信息加到所有的token中,序列还是197*768。

C.Transformer Encoder

Transformer Encoder就是将Transformer Block叠加L次,经过预处理后,输入变成了[197*768]的tensor数据。进入Layer Norm后维度不变,仍然是[197*768]。后进入多头注意力层,以ViT的base版本为例,使用12头,则每个头的q、k和v维度分别变成[197*64(768/12)],进行12组自注意力操作,最后再拼接起来,输出维度仍然是[197*768]。再经过一层layer norm,还是197*768。最后进入MLP Block,全连接+GELU激活函数+Dropout组成。把维度放大到4倍[197, 768] -> [197, 3072],再还原回原节点个数[197, 3072] -> [197, 768]。
进去Transformer block之前是197*768,出来还是197*768,这个序列的长度和每个token对应的维度大小都是一样的,所以就可以在一个Transformer block上不停地往上叠加Transformer block,最后有L层Transformer block的模型就构成了Transformer encoder。Transformer从头到尾都是使用D当作向量的长度的,都是768,同一个模型里这个维度是不变的。如果transformer变得更大了,D也可以相应的变得更大。

D.MLP Head和ViT-B/16模型结构图

对于分类,只需要提取[class]token生成的对应结果就行,即[197,768]中提取出[class]token对应的[1,768],通过MLP head得到最终的分类结果。MLP Head原论文中说再训练ImageNet21k时,是由Linear+tanh激活函数+Linear组成,但是迁移到ImageNet1k或者自己的数据集上时,只定义一个Linear即可。注意,在Transformer Encoder后还有一个Layer Norm。

3.1 vision Transformer
3.2 fine-tuning and higher resolution

四.消融实验

4.1  class token

对于Transformer来说,如果有一个Transformer模型,输入和输出都有n个元素,为什么不直接在n个输出上做全局平均池化得到最后一个特征,而是在最前面加一个class token,最后用class token的输出做分类?
要跟原始的Transformer尽可能保持一致,所以使用了class token(class token 在NLP的分类任务中也有用到,也是当作全局对句子的理解的特征)。本文的class token是将它当作一个图像的整体特征,将这个token的输出送入MLP(MLP中是用tanh当作非线性的激活函数来做分类的预测)。
之前视觉领域是不需要class token,以ResNet-50为例,在最后一个stage出来的是14*14的feature map,然后在这个feature map上其实做了全局平局池化(global average pooling),池化以后的特征已经拉直了,就已经是一个向量了,此时就可以把这个向量理解为一个全局的图片特征,然后再拿这恶鬼特征去做分类。

4.2 位置编码

位置编码涉及到三种消融实验,分别为1D(NLP和本文所使用的)、2D(类似与矩阵下标表示方法)和Relative position embedding(类似与图像块中的位置信息,不同图像块之间可以用绝对距离表示,也可以用相对距离表示,即文中的offset)。

五.Experiments实验部分

4.1 setup
4.2 comparison to state of the art
4.3 Pre-training data requirements
4.4 scaling study
4.5 inspection vision transformer
4.6 self-supervision

六.Conclusion

本文的工作是使用NLP领域标准的Transformer来解决CV的问题,与之前简单结合自注意力方法的区别在于:第一是对图像进行抽图像块,第二是位置编码使用了一些图像所特有的归纳偏置,除此之外没有其他的归纳偏置了。这样做的好处是就是将图像块当作NLP中的token,标准的Transformer就可以来解决CV问题了。该方法效果非常好且便宜。

展望:

(1)不能只做分类,还有检测和分割
后有DETR用于目标检测;ViT-FRCNN用于检测;SETR用于分割;
(2)探索自监督的预训练方法
(3)将ViT变的很大,可能会有很好的效果,Scaling Vision Transformer实现。

创新点

1.将Tansformer直接应用到CV中,只需要对输入图片做预处理(将图片打成patch),完全将一个CV问题转换为NLP问题。
2.ViT借鉴了BERT中的预训练方法(先在大规模的数据集上进行预训练,再放到具体任务的特定数据集上进行微调)和GPT中自注意力机制(与原来CNN相比,每个token都可以访问自身原来信息,能够更好的理解图像中的上下文信息)。
3.任务方面:ViT只做了分类,后续可以做检测、分割;改变结构方面:可以修改tokenization或者中间的transformer block;目标函数方面:监督训练和自监督训练的方式。
4.本文主要两个重点创新:(1)将图像打成patch(2)加入位置编码,提供图像所特有的归纳偏置

知识点补充

1.孤立自注意力:由于Transformer应用于CV有序列缩小的要求,在注意力机制时不适用整张图,而是使用local window(局部的小窗口),通过控制这个窗口的大小来使得计算序列在允许范围内。(类似与卷积)
2.轴自注意力:CV中计算复杂度高是由于序列N=H*W,是一个二维的矩阵,将这个二维矩阵拆分成两个一维向量,先后在高度、宽度的维度做一次self-attention,这样就是将一个二维矩阵的计算转换成了两个一维向量的运算,大大降低了计算复杂度。
3.CNN中的归纳偏置(inductive bias):可以理解为先验知识或者提前做好的假设。有两种,分别为locality和translation equivariance(平移等变性)。前者是指CNN是以滑动窗口的形式一点一点在图片上进行卷积的,假设图片上相邻区域会有相似的特征,靠的近相关性更强。后者用公式表示为f(g(x))=g(f(x)),表示进行f操作和g操作的顺序不会影响结果。
4.消融实验:是一种评估算法性能的方法。核心思想就是删除或修改系统中的特定部分,从而观察这些部分是如何影响系统的功能、性能或行为的。进而了解系统的关键要素,同时验证该部分对系统的重要程度。
5.在计算机领域,无监督训练和有监督训练的主要区别为是否使用已知的标签数据进行训练。有监督训练需要大量的标签数据,而无监督训练不需要。
6.分类头(classfication head)的作用:将模型的原始输入映射到具体的类别标签。在预训练阶段由带有一个隐藏层了MLP实现,在微调阶段由一个线性层实现。其具体实现就是将前一层或前多层的特征图转换为具体的类别标签,通常将每个类别的分数或者概率作为输出的。
7.BatchNorm和LayerNorm的作用和区别:两者都是归一化,加速收敛。前者是对同一feature的batchsize个样本减均值除方差,后者是对同一样本的所有feature进行该操作。

核心结论

1.是对图像进行抽图像块
2.位置编码使用了一些图像所特有的归纳偏置,除此之外没有其他的归纳偏置了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3249931.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

简单工厂、工厂方法与抽象工厂之间的区别

简单工厂、工厂方法与抽象工厂之间的区别 1、简单工厂(Simple Factory)1.1 定义1.2 特点1.3 示例场景 2、工厂方法(Factory Method)2.1 定义2.2 特点2.3 示例场景 3、抽象工厂(Abstract Factory)3.1 定义3.…

关于Centos停更yum无法使用的解决方案

最近在使用Centos7.9系统时候,发现yum仓库无法进行安装软件包了,官方说2024年6月30日进行停更,停更后无法提供对应的软件服务。 我在使用yum安装包的时候发现确实不能使用官方服务了: CentOS停更的影响 CentOS停止更新之后&#…

CentOS 7报错:yum命令报错 “ Cannot find a valid baseurl for repo: base/7/x86_6 ”

参考连接: 【linux】CentOS 7报错:yum命令报错 “ Cannot find a valid baseurl for repo: base/7/x86_6 ”_centos linux yum search ifconfig cannot find a val-CSDN博客 Centos7出现问题Cannot find a valid baseurl for repo: base/7/x86_64&…

韦东山嵌入式linux系列-驱动进化之路:设备树的引入及简明教程

1 设备树的引入与作用 以 LED 驱动为例,如果你要更换LED所用的GPIO引脚,需要修改驱动程序源码、重新编译驱动、重新加载驱动。 在内核中,使用同一个芯片的板子,它们所用的外设资源不一样,比如A板用 GPIO A&#xff0c…

大数据采集工具——Flume简介安装配置使用教程

Flume简介&安装配置&使用教程 1、Flume简介 一:概要 Flume 是一个可配置、可靠、高可用的大数据采集工具,主要用于将大量的数据从各种数据源(如日志文件、数据库、本地磁盘等)采集到数据存储系统(主要为Had…

2024-07-19 Unity插件 Odin Inspector9 —— Validation Attributes

文章目录 1 说明2 验证特性2.1 AssetsOnly / SceneObjectsOnly2.2 ChildGameObjectsOnly2.3 DisallowModificationsIn2.4 FilePath2.5 FolderPath2.6 MaxValue / MinValue2.7 MinMaxSlider2.8 PropertyRange2.9 Required2.10 RequiredIn2.11 RequiredListLength2.12 ValidateIn…

JAVA:Filer过滤器+案例:请求IP访问限制和请求返回值修改

JAVA:Filer过滤器 介绍 Java中的Filter也被称为过滤器,它是Servlet技术的一部分,用于在web服务器上拦截请求和响应,以检查或转换其内容。 Filter的urlPatterns可以过滤特定地址http的请求,也可以利用Filter对访问请求…

鸿蒙语言基础类库:【@system.sensor (传感器)】

传感器 说明: 从API Version 8开始,该接口不再维护,推荐使用新接口[ohos.sensor]。本模块首批接口从API version 4开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。该功能使用需要对应硬件支持,仅支持…

地图项目涉及知识点总结

序:最近做了一个在地图上标记点的项目,用户要求是在地图上显示百万量级的标记点,并且地图仍要可用(能拖拽,能缩放)。调研了不少方法和方案,最终实现了相对流畅的地图系统,加载耗时用…

2024可信数据库发展大会:TDengine CEO 陶建辉谈“做难而正确的事情”

在当前数字经济快速发展的背景下,可信数据库技术日益成为各行业信息化建设的关键支撑点。金融、电信、能源和政务等领域对数据处理和管理的需求不断增加,推动了数据库技术的创新与进步。与此同时,人工智能与数据库的深度融合、搜索与分析型数…

【Git】(基础篇四)—— GitHub使用

GitHub使用 经过上一篇的文章,相信大家已经对git的基本操作熟悉了,但哪些使用git的方法只是在本地仓库进行,本文介绍如何使用git和远程仓库进行连接使用。 Github和Gitee 主要用到的两个远程仓库在线平台是github和gitee GitHub GitHub …

Adobe XD中文设置指南:专业设计师的现场解答

Adobe XD是世界领先的在线合作UI设计工具。它摆脱了Sketch、Figma等传统设计软件对设备的依赖,使设计师可以随时随地使用任何设备打开网页浏览器,轻松实现跨平台、跨时空的设计合作。然后,为了提高国内设计师的使用体验,Adobe XD如…

2024-07-18 Unity插件 Odin Inspector8 —— Type Specific Attributes

文章目录 1 说明2 特定类型特性2.1 AssetList2.2 AssetSelector2.3 ChildGameObjectsOnly2.4 ColorPalette2.5 DisplayAsString2.6 EnumPaging2.7 EnumToggleButtons2.8 FilePath2.9 FolderPath2.10 HideInInlineEditors2.11 HideInTables2.12 HideMonoScript2.13 HideReferenc…

DP(6) | 完全背包 | Java | LeetCode 322, 179, 139 做题总结

322. 零钱兑换 我的错误答案 class Solution {public int coinChange(int[] coins, int amount) {int[][]dp new int [coins.length][amount1];for(int j0; j<amount; j) {if(coins[0] j){dp[0][coins[0]] 1;}}for(int i1; i<coins.length; i) {for(int j0; j<am…

带时间窗车辆路径问题丨论文复现:改进粒子群算法求解

路径优化相关文章 1、路径优化历史文章2、路径优化丨带时间窗和载重约束的CVRPTW问题-改进遗传算法&#xff1a;算例RC1083、路径优化丨带时间窗和载重约束的CVRPTW问题-改进和声搜索算法&#xff1a;算例RC1084、路径优化丨复现论文-网约拼车出行的乘客车辆匹配及路径优化5、…

[C/C++入门][进制原理]27、计算机种的进制

各种信息进入计算机&#xff0c;都要转换成“0”和“1”的二进制形式。 计算机 采用二进制的原因是&#xff1a; 物理上容易实现&#xff0c;可靠性高。&#xff08;电子元件的通电和不通电就可以表示1和0&#xff0c;所以非常方便&#xff09;运算简单&#xff0c;通用性强。…

ELK日志分析系统部署文档

一、ELK说明 ELK是Elasticsearch&#xff08;ES&#xff09; Logstash Kibana 这三个开源工具组成&#xff0c;官方网站: The Elastic Search AI Platform — Drive real-time insights | Elastic 简单的ELK架构 ES: 是一个分布式、高扩展、高实时的搜索与数据分析引擎。它…

Java 网络编程(TCP编程 和 UDP编程)

1. Java 网络编程&#xff08;TCP编程 和 UDP编程&#xff09; 文章目录 1. Java 网络编程&#xff08;TCP编程 和 UDP编程&#xff09;2. 网络编程的概念3. IP 地址3.1 IP地址相关的&#xff1a;域名与DNS 4. 端口号&#xff08;port&#xff09;5. 通信协议5.1 通信协议相关的…

如何免费用java c#实现手机在网状态查询

今天分享手机在网状态查询接口&#xff0c;该接口适用的场景非常广泛&#xff01;首先我们先讲下什么是手机在网状态&#xff1f;简单来说&#xff0c;就是你得手机号是否还在正常使用中&#xff0c;是否能够及时接收和回复信息&#xff0c;是否能够随时接听和拨打电话。如果你…

小白新手搭建个人网盘

小白新手搭建个人网盘 序云服务器ECS重置密码远程连接ECS实例 安装OwnCloud安装Apache服务PHP运行环境NAS挂载挂载验证操作体验 序 阿里云文件存储NAS&#xff08;Apsara File Storage NAS&#xff09;是一个可大规模共享访问&#xff0c;弹性扩展的分布式文件系统。本文主要是…