昇思学习打卡-12-Vision Transformer图像分类

文章目录

  • ViT模型学习
  • 构建模型
    • Multi-Head Attention
    • TransformerEncoder
    • pos_embedding
    • Vit部分实现
  • 推理结果

ViT模型学习

Vision Transformer(ViT)简介
ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型特点
ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  • 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  • 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  • 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

构建模型

Multi-Head Attention

  • 多头注意力可以实现并行加速,且可以保持参数量
from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

TransformerEncoder

class TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):"""Transformer construct."""return self.layers(x)

pos_embedding

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

Vit部分实现

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

推理结果

在这里插入图片描述
此章节学习到此结束,感谢昇思平台。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3226846.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

最简单详细的jwt用户登录校验教程(新手必看)

首先简单建张用户表。 DROP TABLE IF EXISTS user; CREATE TABLE user (id bigint NOT NULL AUTO_INCREMENT,name varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL,username varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL…

C++ 编译体系入门指北

前言 之从入坑C之后,项目中的编译构建就经常跟CMake打交道,但对它缺乏系统的了解,遇到问题又陷入盲人摸象。对C的编译体系是如何发展的,为什么要用CMake,它的运作原理是如何的比较感兴趣,所以就想系统学习…

云手机批量操作使用场景,从Amazon、TK等软件分析

云手机目前所具备的群控,批量操作,自动化等功能,对于电商,软测,办公,直播,营销等行业有很好的减负作用。 针对于具体的海外APP,云手机具体可以做哪些事情来帮助我们减轻压力&#x…

伺服【禾川X6】

驱动器: A:脉冲 B:EtherCAT // SV-X6 FB 040 AA 一套360 N:CANopen R:PROFINET 电机: SV-X6 MA 040A-B2 KA

MongoDB - 集合和文档的增删改查操作

文章目录 1. MongoDB 运行命令2. MongoDB CRUD操作1. 新增文档1. 新增单个文档 insertOne2. 批量新增文档 insertMany 2. 查询文档1. 查询所有文档2. 指定相等条件3. 使用查询操作符指定条件4. 指定逻辑操作符 (AND / OR) 3. 更新文档1. 更新操作符语法2. 更新单个文档 updateO…

Linux--线程ID封装管理原生线程

目录 1.线程的tid(本质是线程属性集合的起始虚拟地址) 1.1pthread库中线程的tid是什么? 1.2理解库 1.3phtread库中做了什么? 1.4线程的tid,和内核中的lwp 1.5线程的局部存储 2.封装管理原生线程库 1.线程的tid…

odoo视图继承

odoo视图继承 在模型时候,不对视图、菜单等进行修改,原视图和菜单等视图数据仍然可以使用,不需要重新构建 form视图继承案例 model:为对应模型 inherit_id:为继承的视图,ref:为继承视图的id&#xff0…

帝特(DTECH)USB转RS485/422串口线在Ubuntu系统中的安装

因为测试需要,买了一根帝特(DTECH)USB转RS485/422串口线,今天测试了一下在Ubuntu 22.04系统上的使用。帝特的网站上提供了驱动程序,下载以后发现接口芯片是CP2102,厂商只提供了Linux内核2.6和3.x版本的驱动…

新版FMEA培训未能达到预期效果怎么办?

在制造业的质量管理中,FMEA(Failure Mode and Effects Analysis,失效模式与影响分析)是一项至关重要的工具,它帮助企业识别和评估产品或过程中潜在的失效模式,以及这些失效模式可能导致的后果。然而&#x…

AIGC技术引领创意设计行业革新,“谁”能成职业发展新引擎?

随着科技的日新月异,生成式人工智能(AIGC)技术正迅速崛起,成为创意设计领域的一股强大新势力。该技术不仅显著提升了设计师的工作效率,更为他们打开了前所未有的创意空间。在这一波技术浪潮中,Adobe国际认证…

AutoMQ 与蚂蚁数科达成战略合作

近期,AutoMQ 与蚂蚁数科正式签署战略合作协议,将和蚂蚁数科云原生 PaaS 平台 SOFAStack 在产品研发、生态集成、市场合作、技术社区影响力等多方面开展深度合作。 AutoMQ 是业内领先的消息和流存储服务提供商,基于云原生基础设施重新设计了 …

如何整合生成的人工智能?(GenAI)为你未来的工作增加动力

生成人工智能(GenAI)它发展迅速,以前所未有的速度取得了突破。人工智能将继续改变各行各业,预计2023年至2030年的年增长率将达到37.3%。由于一种新的知识工作者现在面临被取代的风险,生成式人工智能的惊人崛起进一步加剧了这种紧迫性。据《未…

虚拟机内安装vue-dev-tools

前言 项目开发调试都需要在Citrix在虚拟机环境下,Citrix内连接不到外网,在这边文章,我将介绍自己在Citrix环境内安装 vue-dev-tools的经验 环境 vue 步骤 1. 下载.crx文件 百度网盘里的 .crx文件的 下载链接 2. 加载.crx文件 打开浏览…

即时通讯平台项目测试(主页面)

http://8.130.98.211:8080/login.html项目访问地址:即时通讯平台http://8.130.98.211:8080/login.html 本篇文章进行项目主页面的测试。 在测试前需要先对待测内容进行分类,按照功能进行分类可以分为:个人信息设置、发送/接收消息、添加好友…

领夹麦克风哪个品牌好,哪个麦克风好,热门无线麦克风品牌推荐

​无线领夹麦克风是现代沟通的重要工具,它不仅提高了语音交流的清晰度,还展现了使用者的专业形象。随着技术发展,这些麦克风已经变得更加轻便、时尚,易于使用。在各种场合,如演讲、教育和网络直播中,当然&a…

请跳至打印机属性的“Adobe PDF设置”页面,取消选择“仅停靠系统字体;不使用文档字体”

场景: 当使用adobe pdf打印时,出现如下提示“请跳至打印机属性的“Adobe PDF设置”页面,取消选择“仅停靠系统字体;不使用文档字体””,该如何解决。 描述 □“仅停靠系统字体;不使用文档字体” 复选本框…

Git常见命令和用法

Git 文件状态 Git 文件 2 种状态: 未跟踪:新文件,从未被 Git 管理过已跟踪:Git 已经知道和管理的文件 常用命令 命令作用注意git -v查看 git 版本git init初始化 git 仓库初始化之后有工作区、暂存区(本地库)、版本库git add 文件标识暂存某个文件文件标识以终…

如何分析软件测试中发现的Bug!

假如你是一名软件测试工程师,每天面对的就是那些“刁钻”的Bug,它们像是隐藏在黑暗中的敌人,时不时跳出来给你一个“惊喜”。那么,如何才能有效地分析和处理这些Bug,让你的测试工作变得高效且有趣呢?今天我…

Python的语言特性

1,python是动态语言 在编译期间就确定变量类型的语言是静态语言 在运行期间才知道变量类型的是动态语言 2,python是强类型语言 不同类型的变量是否允许隐式转换

院内导航:如何用科技破解就医找路难题

自2019年开始“院内导航”被纳入医院智慧服务评估体系以来,到2023年改善就医服务升级的部署,每一步都见证了我国医疗卫生体系向智能化、人性化迈进的坚实步伐。 面对庞大复杂的医院环境与日益增长的就诊需求,如何让患者在茫茫人海中迅速找到就…