【AI】计算机视觉VIT文章(Transformer)源码解析

论文:Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020

源码的Pytorch版:https://github.com/lucidrains/vit-pytorch

0.前言

Transformer提出后在NLP领域中取得了极好的效果,其全Attention的结构,不仅增强了特征提取能力,还保持了并行计算的特点,可以又快又好的完成NLP领域内几乎所有任务,极大地推动自然语言处理的发展。
VIT这篇文章就是将Transformer模型应用在了CV领域,它将图像处理成Transformer模型可以应用的形式,沿用NLP领域中Transformer的方法,直接验证了其精度可以和ResNet不相上下,展示了在计算机视觉中使用纯Transformer结构的可能,为Transformer在CV领域的应用打开了大门。

直接读文章通常比较抽象,英文的原文更能劝退一大部分人,但对于程序员来说,代码是通行于世界的语言,理解起来就比较简单,结合源码理解论文中的结构,就比较事半功倍。

在这里插入图片描述

上图是VIT文章中的结构,我们看图提问题,从数据的流向来看:图像怎么切分重排的?Linear Projection of Flattened Patches对图像作了什么,怎么让图像变成Transformer能够输入的格式?
Position Embedding是怎么做图像位置编码的,为什么会多出来一个0的位置编码?Transformer Encoder中的各个结构分别代表什么,是怎么实现的?输出的类别是什么?

1.图像是如何适配Transformer输入的?

Transformer的输入是一个一维的向量,而我们的图像是二维的,需要把图像拉伸成一维的,最简单的方式就是沿着x轴展开,将所有的行拼接在一起,也就是Flatten的操作,但是这样处理会导致向量维度比较大,而且同一张图片也只能生成一个Embedding,不能适配Multi-head Embedding(这种解释有点牵强,有点拿结果去解释原因)。

1.1 图像切分重排

如结构图所示,VIT采取了图像切分重排的方式,将一个完整的图像按照行列的方向切分成小块,然后再进行后续处理。切分重排是怎么实现的,我们看代码:

self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), #图片切分重排nn.Linear(patch_dim, dim), # Linear Projection of Flattened Patches)

这里的Rearrange使用到了einops的库,相关的介绍可以查看文档
这里主要是对b c (h p1) (w p2) -> b (h w) (p1 p2 c)表达式的理解,b是batchsize,c是chanel,h和w是图像的高度和宽度,表达式前边的(h p1)表示将输入的h重新拆分成一个h*p1的向量,同理(w p2)是将输入的w拆分成w*p2的向量,这里的p1和p2是模型定义的patch_height和patch_width,可以理解为切分后小图的高和宽;
表达式右边表示输出向量的维度,(p1 p2 c)表示这3个数相乘,表示的是切分后一个小图的一维向量大小,(h w)则表示总共有多少个切分后的小图所生成的向量。这里的h和w的值跟输入的h和w值不同,表示的是原有h和w除以patch_height或patch_width的值,也就是在高和宽上各能切分出几个小图。
通过这样的处理,就将输入的c个channel的h*w的二维向量,转换成小图展开后的一维向量。

1.2 构建patch0

图像在输入Transformer之前,还连接了一个patch0,这一步的操作可以认为是延续了nlp的操作,在VIT这边的操作,可以认为是将分类类别拼接上去。

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)

1.3 Positional Embedding

为了在模型中引入位置信息,VIT引入了位置编码的形式,也就是图中的0、1、2、3…

x += self.pos_embedding[:, :(n + 1)]

然后将position_embedding与图像的Embedding相加。

2.Transformer Encoder中的各个结构分别代表什么,是怎么实现的?

结构中的Norm可以认为是归一化处理,MLP是多个全连接层,理解起来都比较简单。其实最主要的结构就是Multi-head Attention。
在这里插入图片描述

2.1 Attention 定义

class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):b, n, _, h = *x.shape, self.headsqkv = self.to_qkv(x).chunk(3, dim = -1)#得到qkvq, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scaleattn = self.attend(dots)out = einsum('b h i j, b h j d -> b h i d', attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)

2.2 Attention 的理解

VIT不同于之前RNN的地方就是引入了Attention机制,其本质可以认为是相似度计算,是计算每一个输入值与其他输入值的相似度,然后带入到计算中,如图所示,q是输入的查询值,k是关键词,v是计算值,计算每一个q与其他k的相似度,然后再带入计算中去。
在这里插入图片描述

其对应的计算公式为:
在这里插入图片描述
相似度计算是用的点积的形式,除以根号下dk是为了抑制极端值,保证softmax之后数值不至于丧失梯度。对应的代码如下:

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

然后过一个softmax

attn = self.attend(dots) # self.attend = nn.Softmax(dim = -1)
attn = self.dropout(attn)

然后跟value值相乘

out = torch.matmul(attn, v)

2.3 Multihead Attention

多头注意力机制可以认为是定义多个Attention(每个attention关注的重点不同)分别来对数据进行处理,这里从源码中的循环结构可以体现出来:

class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.norm = nn.LayerNorm(dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),FeedForward(dim, mlp_dim, dropout = dropout)]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn self.norm(x)

3.使用示例

这里可以参考官方的readme文档

$ pip install vit-pytorch

使用示例

import torch
from vit_pytorch import ViTv = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1
)img = torch.randn(1, 3, 256, 256)preds = v(img) # (1, 1000)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2659564.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Erlang、RabbitMQ下载与安装教程(windows超详细)

目录 安装Erlang 1.首先安装RabbitMQ需要安装Erlang环境 2.点击下载好的.exe文件进行傻瓜式安装,一直next即可 3.配置Erlang环境变量 安装RabbitMQ 1.给出RabbitMQ官网下载址:Installing on Windows — RabbitMQ,找到 2.配置RabbitMQ环境变量&#xff0…

软件测试/测试开发丨Pytest学习笔记

Pytest 格式要求 文件: 以 test_ 开头或以 _test 结尾类: 以 Test 开头方法/函数: 以 _test 开头测试类中不可以添加构造函数, 若添加构造函数将导致Pytest无法识别类下的测试方法 断言 与Unittest不同, 在Pytest中我们需要使用python自带的 assert 关键字进行断言 assert…

【强化学习】基于蒙特卡洛MC与时序差分TD的简易21点游戏应用

1. 本文将强化学习方法(MC、Sarsa、Q learning)应用于“S21点的简单纸牌游戏”。 类似于Sutton和Barto的21点游戏示例,但请注意,纸牌游戏的规则是不同且非标准的。 2. 为方便描述,过程使用代码截图,文末附链…

【k8s源码分析-Apiserver-2】kube-apiserver 结构概览以及主体部分源码分析

参考 Kubernetes 源码剖析(书籍)kube-apiserver的设计与实现 - 自记小屋 kube-apiserver 核心思想 APIGroupInfo 记录 GVK 与 Storage 的对应关系 将 GVK 转换成,Restful HTTP Path将 Storage 封装成 HTTP Handler将上面两个形成映射&#…

CSP CCF 201409-2 画图 C++满分题解

解题思路&#xff1a; 1.使用二维数组标记每一个方块是否被涂色。 2.注意坐标代表的是点&#xff0c;不是方块&#xff0c;交界处的坐标只能算一个方块。 3.可以看成&#xff1a;每一个坐标都对应它左上角的一个小方块&#xff0c;这样可以避免重复计算方块数 #include<i…

3 个月前被裁员了,心情跌落谷底,直到我看到了这本神书…

3个月前的某一天&#xff0c;正在愉快的打工&#xff0c;突然被喊去谈话&#xff0c;然后就被辞退了。。 加入了找工作的大军 然而&#xff0c;因为疫情&#xff0c;因为大专学历的我&#xff0c;找工作比以往都艰难了许多 很多&#xff0c;纯粹就是因为学历&#xff0c;都不…

第十三届蓝桥杯大赛青少年国赛C++组编程题真题(2022年)

第十三届蓝桥杯大赛青少年国赛C组编程题真题&#xff08;2022年&#xff09; 编程题第 1 题 问答题 电线上的小鸟 题目描述&#xff1a; 在一根电线上落有N只小鸟&#xff0c;有的小鸟头向左看&#xff0c;有的小鸟头向右看&#xff0c;且每只小鸟只能看到它视线前的那一只小…

Java之ThreadLocal 详解

ThreadLocal 详解 原文地址&#xff1a;https://juejin.cn/post/6844904151567040519open in new window。 什么是ThreadLocal&#xff1f; ThreadLocal提供线程局部变量。这些变量与正常的变量不同&#xff0c;因为每一个线程在访问ThreadLocal实例的时候&#xff08;通过其…

机器学习之人工神经网络(Artificial Neural Networks,ANN)

人工神经网络(Artificial Neural Networks,ANN)是机器学习中的一种模型,灵感来源于人脑的神经网络结构。它由神经元(或称为节点)构成的层级结构组成,每个神经元接收输入并生成输出,这些输入和输出通过权重进行连接。 人工神经网络(ANN)是一种模仿生物神经系统构建的…

电气产品外壳常用材质PA、PC、PBT、ABS究竟是什么?

在如今工业制造领域&#xff0c;各种改性塑料、复合材料以及轻质合金材料的运用日趋成熟。在电气领域&#xff0c;不同电气产品的外壳、组件材质采用不同材料&#xff0c;以同为科技&#xff08;TOWE&#xff09;电气产品为例&#xff0c;工业连接器系列产品采用PA6外壳材质、机…

软件测试/测试开发丨Python常用数据结构学习笔记

Python常用数据结构 list 列表 列表定义 列表是有序的可变元素的集合&#xff0c;使用中括号[]包围&#xff0c;元素之间用逗号分隔列表是动态的&#xff0c;可以随时扩展和收缩列表是异构的&#xff0c;可以同时存放不同类型的对象列表中允许出现重复元素 列表使用&#x…

鸿蒙 DevEco Studio 3.1 入门指南

本文主要记录开发者入门&#xff0c;从软件安装到项目运行&#xff0c;以及后续的学习 1&#xff0c;配置开发环境 1.1 下载安装包 1.2 下载SDK及工具链 1.3 诊断开发环境 1.4 配置环境变量 配置HDC工具环境变量 配置Node环境变量 2.创建和运行 2.1 新建项目 2.2 页面…

掌握激活函数(一):深度学习的成功之源

文章目录 引言基本概念常用激活函数举例Sigmoid激活函数公式Sigmoid函数的数学特性示例基于NumPy和PyTorch实现Sigmoid函数将Sigmoid函数应用于二分类任务 Sigmoid激活函数的局限性举例 ReLU激活函数公式ReLU函数的数学特性ReLU函数的特点示例基于NumPy和PyTorch实现ReLU函数搭…

虚析构和纯虚析构

多态使用时&#xff0c;如果子类中有属性开辟到堆区&#xff0c;那么父类的指针在释放时无法调用到子类的析构代码 解决方式&#xff1a;将父类中的析构代码函数改为虚析构或者纯虚析构 虚析构和纯虚析构共性&#xff1a; 可以解决父类指针释放子类对象 都需要有具体的函数…

vu3-14

第一个需求是在用户登录成功之后&#xff0c;在主页显示用户的真实姓名和性别&#xff0c;这些信息要调用后端API获取数据库里面的信息&#xff0c;第二个需求是点击菜单1&#xff0c;在表单中修改用户信息之后&#xff0c;更新到后端数据库&#xff0c;然后在主页同步更新用户…

软件测试/测试开发丨接口测试学习笔记分享

一、Mock 测试 1、Mock 测试的场景 前后端数据交互第三方系统数据交互硬件设备解耦 2、Mock 测试的价值与意义 不依赖第三方数据节省工作量节省联调 3、Mock 核心要素 匹配规则&#xff1a;mock的接口&#xff0c;改哪些接口&#xff0c;接口哪里的数据模拟响应 4、mock实…

【基础篇】六、自定义类加载器打破双亲委派机制

文章目录 1、ClassLoader抽象类的方法源码2、打破双亲委派机制&#xff1a;自定义类加载器重写loadclass方法3、自定义类加载器默认的父类加载器4、两个自定义类加载器加载相同限定名的类&#xff0c;不会冲突吗&#xff1f;5、一点思考 1、ClassLoader抽象类的方法源码 ClassL…

【Python基础篇】【19.异常处理】(附案例,源码)

异常处理 异常处理常见异常elsefinallyraise获取异常信息sys.exc_info()traceback 处理异常基本原则assert断点调试两种方式Debugger窗口各图标的含义1.Show Execution Point &#xff08;Alt F10&#xff09;2.Step Over&#xff08;F8&#xff09;3.Step Into &#xff08;F…

【线性代数】通过矩阵乘法得到的线性方程组和原来的线性方程组同解吗?

一、通过矩阵乘法得到的线性方程组和原来的线性方程组同解吗&#xff1f; 如果你进行的矩阵乘法涉及一个线性方程组 Ax b&#xff0c;并且你乘以一个可逆矩阵 M&#xff0c;且产生新的方程组 M(Ax) Mb&#xff0c;那么这两个系统是等价的&#xff1b;它们具有相同的解集。这…

k8s二进制部署2

部署 Worker Node 组件 //在所有 node 节点上操作 #创建kubernetes工作目录 mkdir -p /opt/kubernetes/{bin,cfg,ssl,logs} #上传 node.zip 到 /opt 目录中&#xff0c;解压 node.zip 压缩包&#xff0c;获得kubelet.sh、proxy.sh cd /opt/ unzip node.zip chmod x kubelet.…