昇思25天学习打卡营第16天 | Vision Transformer图像分类

昇思25天学习打卡营第16天 | Vision Transformer图像分类

文章目录

  • 昇思25天学习打卡营第16天 | Vision Transformer图像分类
    • Vision Transform(ViT)模型
      • Transformer
        • Attention模块
        • Encoder模块
      • ViT模型输入
    • 模型构建
      • Multi-Head Attention模块
      • Encoder模块
      • Patch Embedding模块
      • ViT网络
    • 总结
    • 打卡

Vision Transform(ViT)模型

ViT是NLP和CV领域的融合,可以在不依赖于卷积操作的情况下在图像分类任务上达到很好的效果。

ViT模型的主体结构是基于Transformer的Encoder部分。

Transformer

Transformer由很多Encoder和Decoder模块构成,包括多头注意力(Multi-Head Attention)层,Feed Forward层,Normalization层和残差连接(Residual Connection)。
encoder-decoder
多头注意力结构基于自注意力机制(Self-Attention),是多个Self-Attention的并行组成。

Attention模块

Attention的核心在于为输入向量的每个单词学习一个权重。

  1. 最初的输入向量首先经过Embedding层映射为Q(Query),K(Key),V(Value)三个向量。
  2. 通过将Q和所有K进行点乘初一维度平方根,得到向量间的相似度,通过softmax获取每词向量之间的关系权重。
  3. 利用关系权重对词向量的V加权求和,得到自注意力值。
    self-attention
    多头注意力机制只是对self-attention的并行化:
    multi-head-attention
Encoder模块

ViT中的Encoder相对于标准Transformer,主要在于将Normolization放在self-attention和Feed Forward之前,其他结构与标准Transformer相同。
vit-encoder

ViT模型输入

传统Transformer主要应用于自然语言处理的一维词向量,而图像时二维矩阵的堆叠。
在ViT中:

  1. 通过卷积将输入图像在每个channel上划分为 16 × 16 16\times 16 16×16个patch。如果输入 224 × 224 224\times224 224×224的图像,则每一个patch的大小为 14 × 14 14\times 14 14×14
  2. 将每一个patch拉伸为一个一维向量,得到近似词向量堆叠的效果。如将 14 × 14 14\times14 14×14展开为 196 196 196的向量。
    这一部分Patch Embedding用来替换Transformer中Word Embedding,用作网络中的图像输入。

模型构建

Multi-Head Attention模块

from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

Encoder模块

from typing import Optional, Dictclass FeedForward(nn.Cell):def __init__(self,in_features: int,hidden_features: Optional[int] = None,out_features: Optional[int] = None,activation: nn.Cell = nn.GELU,keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)def construct(self, x):"""Feed Forward construct."""x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)return xclass ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = celldef construct(self, x):"""ResidualCell construct."""return self.cell(x) + xclass TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):"""Transformer construct."""return self.layers(x)

Patch Embedding模块

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

ViT网络

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

总结

这一节对Transformer进行介绍,包括Attention机制、并行化的Attention以及Encoder模块。由于传统Transformer主要作用于一维的词向量,因此二维图像需要被转换为类似的一维词向量堆叠,在ViT中通过将Patch Embedding解决这一问题,并用来代替传统Transformer中的Word Embedding作为网络的输入。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3247387.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

BiLSTM 实现股票多变量时间序列预测(PyTorch版)

前言 系列专栏:【深度学习:算法项目实战】✨︎ 涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对…

三、GPIO口

我们在刚接触C语言时,写的第一个程序必定是hello world,其他的编程语言也是这样类似的代码是告诉我们进入了编程的世界,在单片机中也不例外,不过我们的传统就是点亮第一个LED灯,点亮电阻,电容的兄弟&#x…

【Java项目笔记】01项目介绍

一、技术框架 1.后端服务 Spring Boot为主体框架 Spring MVC为Web框架 MyBatis、MyBatis Plus为持久层框架,负责数据库的读写 阿里云短信服务 2.存储服务 MySql redis缓存数据 MinIO为对象存储,存储非结构化数据(图片、视频、音频&a…

防溺水预警系统引领水域安全新篇章

一、系统概述 随着人们对水域活动的需求增加,溺水事故频发,给人们的生命安全带来了严重威胁。然而,如今,一项创新科技正在以强大的功能和无限的潜力引领着水域安全的新篇章。智能防溺水预警系统,作为一种集成了智能感知…

神经网络构造

目录 一、神经网络骨架:二、卷积操作:三、卷积层:四、池化层:五、激活函数(以ReLU为例): 一、神经网络骨架: import torch from torch import nn#神经网络 class CLH(nn.Module):de…

【概率论三】参数估计:点估计(矩估计、极大似然法)、区间估计

文章目录 一. 点估计1. 矩估计法2. 极大似然法2.1. 似然函数2.2. 极大似然估计法 3. 评价估计量的标准3.1. 无偏性3.2. 有效性3.3. 一致性 二. 区间估计1. 区间估计的概念2. 正态总体参数的区间估计 参数估计讲什么 由样本来确定未知参数参数估计分为点估计与区间估计 一. 点估…

Golang面试题整理(持续更新...)

文章目录 Golang面试题总结一、基础知识1、defer相关2、rune 类型3、context包4、Go 竞态、内存逃逸分析5、Goroutine 和线程的区别6、Go 里面并发安全的数据类型7、Go 中常用的并发模型8、Go 中安全读写共享变量方式9、Go 面向对象是如何实现的10、make 和 new 的区别11、Go 关…

【Pytorch】RNN for Name Classification

参考学习来自: https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.htmlRNN完成姓名分类https://download.pytorch.org/tutorial/data.zip 导入库 import glob # 用于查找符合规则的文件名 import os import unicodedata import stri…

【linux】信号的理论概述和实操

目录 理论篇 信号概述 信号的分类 信号机制 理解硬件中断 异步 信号对应的三种动作 信号产生的条件 终端按键 系统调用 软件条件 硬件异常 除0错误 野指针 OS对于错误的态度 信号在进程中的内核数据结构 信号的处理 CPU的内核态和用户态概述 进程处理信号的时…

dom4j 操作 xml 之按照顺序插入标签

最近学了一下 dom4j 操作 xml 文件,特此记录一下。 public class Dom4jNullTagFiller {public static void main(String[] args) throws DocumentException {SAXReader reader new SAXReader();//加载 xml 文件Document document reader.read("C:\\Users\\24…

Hadoop3:MR程序的数据倾斜问题处理

一、数据倾斜 什么是数据倾斜? 学过Redis集群的都知道数据倾斜这个问题。 就是大量数据,分配不均匀的现象。 二、MR数据倾斜 1、怎么判断出现数据倾斜? 数据频率倾斜——某一个区域的数据量要远远大于其他区域。 数据大小倾斜——部分记…

[iOS]浅析isa指针

[iOS]浅析isa指针 文章目录 [iOS]浅析isa指针isa指针isa的结构isa的初始化注意事项 上一篇留的悬念不止分类的实现 还有isa指针到底是什么 它是怎么工作的 class方法又是怎么运作的 class_data_bits_t bits; // class_rw_t * plus custom rr/alloc flags 这里面的class又是何方…

【BUG】已解决:Uncaught SyntaxError: Unexpected token ‘<‘

已解决:Could not install packages due to an EnvironmentError: [Errno 13] Permission denied 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司,热衷分享知识,武汉城市开发者社区主理人 …

用 WireShark 抓住 TCP

Wireshark 是帮助我们分析网络请求的利器,建议每个同学都装一个。我们先用 Wireshark 抓取一个完整的连接建立、发送数据、断开连接的过程。 简单的介绍一下操作流程。 1、首先打开 Wireshark,在欢迎界面会列出当前机器上的所有网口、虚机网口等可以抓取…

Nginx、LNMP万字详解

目录 Nginx 特点 Nginx安装 添加Nginx服务 Nginx配置文件 全局配置 HTTP配置 状态统计页面 Nginx访问控制 授权用户 授权IP 虚拟主机 基于域名 测试 基于IP 测试 基于端口 测试 LNAMP 解析方式 LNMP转发php-fpm解析 Nginx代理LAMP解析 LNMP部署示例 实…

github汉化

GitHub: Let’s build from here GitHub 进入后在顶部搜索栏直接搜索 选第一个点进去后根据教程安装即可 注意如果是chrome浏览器的话 安装篡改猴如果这个问题参考这个链接 chrome浏览器安装篡改猴脚本不执行(Please enable developer mode to allow userscript…

数据结构(单链表算法题)

1.删除链表中等于给定值 val 的所有节点。 OJ链接 typedef struct ListNode ListNode;struct ListNode {int val;struct ListNode* next; };struct ListNode* removeElements(struct ListNode* head, int val) {//创建新链表ListNode* newhead, *newtail;newhead newtail N…

IDEA工具中Java语言写小工具遇到的问题

一:读取excel时遇到 org/apache/poi/ss/usermodel/WorkbookProvider 解决办法: 在pom.xml中把poi的引文包放在最前面即可

小技巧:通过命令行和网络连接获取电脑所连wifi密码

方法一:命令行 第一步,命令行输入下列命令,获取连接过的wifi netsh wlan show profile 第二步,输入以下命令查看你想看的wifi密码(红色改为你获取的任意一个wifi名称) netsh wlan show profile happy key…

高数知识补充----矩阵、行列式、数学符号

矩阵计算 参考链接:矩阵如何运算?——线性代数_矩阵计算-CSDN博客 行列式计算 参考链接:实用的行列式计算方法 —— 线性代数(det)_det线性代数-CSDN博客 参考链接:行列式的计算方法(含四种,…