《昇思25天学习打卡营第6天|ResNet50图像分类》

写在前面

从本次开始,接触一些上层应用。
本次通过经典的模型,开始本次任务。这里开始学习resnet50网络模型,应该也会有resnet18,估计18的模型速度会更快一些。

resnet

通过对论文的结论进行展示,说明了模型的功能,解决了卷积网络层数加大后模型的退化问题。20层和56层相比,层数越大,模型效果越差,因此resnet主要解决这种问题。hekaiming是真的强呀。

基本流程

  1. 整理模型数据
  2. 构建模型网络核心逻辑(ResidualBlockBase/ResidualBlock)
  3. 创建模型一层

构建网络的代码

from typing import Type, Union, List, Optional
import mindspore.nn as nn
from mindspore.common.initializer import Normal# 初始化卷积层与BatchNorm的参数
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class ResidualBlockBase(nn.Cell):expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x  # shortcuts分支out = self.conv1(x)  # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out

创建模型一层

def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None  # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)

创建模型

搭建一个4层的网络。

from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return x

接下来,连接数据和模型网络,开始构建容易使用的网络。在这里设置了,模型残差的方法和每个block。

def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型download(url=model_url, path=pretrained_ckpt, replace=True)param_dict = load_checkpoint(pretrained_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"""ResNet50模型"""resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

模型训练和评估

并没有完全训练,使用了预训练的方法,下载了预训练的模型。

# 定义ResNet50网络
network = resnet50(pretrained=True)# 全连接层输入层的大小
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=10)
# 重置全连接层
network.fc = fc

有了模型网络,接下来需要进行模型训练。训练的过程要设置学习率、优化器和损失函数。

# 设置学习率
num_epochs = 1
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,step_per_epoch=step_size_train, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = network(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss

之后进行多个epoch的迭代,实现模型训练的目标。

import mindspore.ops as opsdef train(data_loader, epoch):"""模型训练"""losses = []network.set_train(True)for i, (images, labels) in enumerate(data_loader):loss = train_step(images, labels)if i % 100 == 0 or i == step_size_train - 1:print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' %(epoch + 1, num_epochs, i + 1, step_size_train, loss))losses.append(loss)return sum(losses) / len(losses)def evaluate(data_loader):"""模型验证"""network.set_train(False)correct_num = 0.0  # 预测正确个数total_num = 0.0  # 预测总数for images, labels in data_loader:logits = network(images)pred = logits.argmax(axis=1)  # 预测结果correct = ops.equal(pred, labels).reshape((-1, ))correct_num += correct.sum().asnumpy()total_num += correct.shape[0]acc = correct_num / total_num  # 准确率return acc# 开始循环训练
print("Start Training Loop ...")for epoch in range(num_epochs):curr_loss = train(data_loader_train, epoch)curr_acc = evaluate(data_loader_val)print("-" * 50)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, curr_loss, curr_acc))print("-" * 50)# 保存当前预测准确率最高的模型if curr_acc > best_acc:best_acc = curr_accms.save_checkpoint(network, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)

进行多轮训练之后,达到训练的目的,模型开始进行收敛,并且能够获取到最终的结果。

最后进行评估,这个并不复杂。

打开

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3267753.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Ubuntu 22.04如何设置中文输入法

前言 近期整理了一下之前在ubuntu 22.04 中如何设置中文输入法的过程,对于本人比较适应读中文写中文来说,这是我安装后的第一步。 一、流程 1.1 安装中文语言包(如果还未安装) 首先是安装中文语言包,直接在终端输入…

全能Ai助手:写作到设计,宝藏神器帮你事半功倍

今天,就让我们一起踏上这场寻找“隐藏”宝藏的旅程,探索这些AI工具如何改变我们的生活! 一、高效生产力的提升之道 1. 文案创作助手 案例:某位自媒体博主使用了一款智能写作工具,不仅大大节省了写作时间,…

数据库(MySQL)-视图、存储过程、触发器

一、视图 视图的定义、作用 视图是从一个或者几个基本表(或视图)导出的表。它与基本表不同,是一个虚表。但是视图只能用来查看表,不能做增删改查。 视图的作用:①简化查询 ②重写格式化数据 ③频繁访问数据库 ④过…

【React学习打卡第四天】

ReactRouter 一、概念二、创建路由开发环境三、快速开始四、抽象路由模块实际开发中的router配置 五、路由导航1.声明式导航2.编程式导航 六、路由导航传参1.searchParams 传参2.params 传参 七、嵌套路由配置八、默认二级路由九、404路由配置十、俩种路由模式 一、概念 一个路…

使用Python爬虫采集亚马逊新品榜商品数据

一、引言 1.1 亚马逊新品榜的重要性 亚马逊是全球最大的电商平台之一,亚马逊新品榜展示了最新上架并受欢迎的产品。对于电商卖家和市场分析师来说,了解这些新品榜单可以帮助他们捕捉市场趋势,了解消费者喜好,从而优化产品策略和营…

视频怎么加密?常见的四种视频加密方法和软件

视频加密是一种重要的技术手段,用于保护视频内容不被未经授权的用户获取、复制、修改或传播。在加密过程中,安企神软件作为一种专业的加密工具,可以发挥重要作用。 以下将详细介绍如何使用安企神软件对视频进行加密,并探讨视频加密…

block_size设置过大错误分析(查看CUDA设备线程块大小)

block_size设置过大错误分析(查看CUDA设备线程块大小) 1 问题描述2 问题分析3 解决方法4 调试和验证5 查看设备线程块大小 1 问题描述 本人作为CUDA编程初学者,在学习编写使用CUDA计算矩阵相乘代码时发现,如果我的 block_size &g…

可能是最好的工具网站

前些苏音在刷视频,发现了一堆好用的宝藏网站,这就赶快分享给大家。 工具网站 这个网站类似于网址导航,集合了包括工具类、资源类、软件类、AI类的合集 并且站长表示励志做体验感最好的工具网,聚焦最快解决用户的需求 首先就是办…

数据库安全综合治理方案(可编辑54页PPT)

引言:数据库安全综合治理方案是一个系统性的工作,需要从多个方面入手,综合运用各种技术和管理手段,确保数据库系统的安全稳定运行。 方案介绍: 数据库安全综合治理方案是一个综合性的策略,旨在确保数据库系…

【8月EI会议推荐】第四届区块链技术与信息安全国际会议

一、会议信息 大会官网:http://www.bctis.nhttp://www.icbdsme.org/ 官方邮箱:icbctis126.com 组委会联系人:杨老师 19911536763 支持单位:中原工学院、西安工程大学、齐鲁工业大学(山东省科学院)、澳门…

一天搞定React(3)——Hoots组件【已完结】

Hello!大家好,今天带来的是React前端JS库的学习,课程来自黑马的往期课程,具体连接地址我也没有找到,大家可以广搜巡查一下,但是总体来说,这套课程教学质量非常高,每个知识点都有一个…

数据结构经典测试题4

1. #include <stdio.h> int main() { char *str[3] {"stra", "strb", "strc"}; char *p str[0]; int i 0; while(i < 3) { printf("%s ",p); i; } return 0; }上述代码运行结果是什么&#xff1f; A: stra strb strc B: s…

Rocky/Centos Linux安装Code-server,并注册成服务自启动

文章目录 Rocky/Centos Linux安装Code-server&#xff0c;并注册成服务自启动介绍安装1. 下载压缩包2. 解压缩3. 执行启动命令4. 浏览器访问5. 开机自启动 Rocky/Centos Linux安装Code-server&#xff0c;并注册成服务自启动 介绍 VS Code Server是微软推出的VSCode风格的Web…

谷歌AI拿下IMO奥数银牌!6道题轻松解出4道~

本周四&#xff0c;谷歌DeepMind团队宣布了一项令人瞩目的成就&#xff1a;&#xff1a;用 AI 做出了今年国际数学奥林匹克竞赛 IMO 的真题&#xff0c;并且距拿金牌仅一步之遥。这一成绩不仅标志着人工智能在数学推理领域的重大突破&#xff0c;也引发了全球范围内的广泛关注和…

私域电商丨软件系统开发中,一定要避开的几个坑,看懂少很多弯路

文丨微三云胡佳东&#xff0c;点击上方“关注”&#xff0c;为你分享市场商业模式电商干货。 - 大家好&#xff0c;我是软件开发胡佳东&#xff0c;每天为大家分享互联网资讯干货&#xff01; 在数字化时代的今天&#xff0c;软件开发是已经成为推动科技进步和商业发展的重要…

vmware虚拟机安装linux没有IP地址

直接设置固定IP 1、在虚拟机菜单栏选择编辑&#xff0c;然后点击虚拟网络编辑器 2、选择Vmnet8 Net网络连接方式&#xff0c;随意设置子网IP 3、点击NAT设置页面&#xff0c;查看子网掩码和网关&#xff0c;修改静态IP会用到 4、打开电脑控制面板–网络和Internet–网络连…

面试常考Linux指令

文件权限 操作系统中每个文件都拥有特定的权限、所属用户和所属组。权限是操作系统用来限制资源访问的机制&#xff0c;在 Linux 中权限一般分为读(readable)、写(writable)和执行(executable)&#xff0c;分为三组。分别对应文件的属主(owner)&#xff0c;属组(group)和其他用…

前端知识笔记之HTML

1.标签元素与属性&#xff0c;注意事项 2.多级标签排序List&#xff0c;无顺序&#xff08;Ul&#xff09;和有顺序(Ol) 3.HTML页面结构 4.页面跳转&#xff0c;注意#是统一页面的跳转 5.图片、视频、音频 标签 6.前端表单与后端方法 数据接收的demo 7.常见表单项 8.注意日期类…

Python爬虫知识体系-----Urllib库的使用

数据科学、数据分析、人工智能必备知识汇总-----Python爬虫-----持续更新&#xff1a;https://blog.csdn.net/grd_java/article/details/140574349 文章目录 1. 基本使用2. 请求对象的定制3. 编解码1. get请求方式&#xff1a;urllib.parse.quote&#xff08;&#xff09;2. ur…