昇思25天学习打卡营第1天|快速入门-实现一个简单的深度学习模型

目录

实验环境

Jupyter云上开发环境使用

导包

 处理数据集

网络构建

 模型训练

评估模型性能

保存模型

加载模型

预测推理


实验环境

02-快速入门.ipynb (4) - JupyterLab (mindspore.cn)

规格:4u 16G 20G

镜像:py39-ms2.3.0rc1

特性:预装mindspore(2.3.0rc1)、numpy、pandas等依赖

Jupyter云上开发环境使用

1. %%capture captured_output: 

是一个在Jupyter Notebook或JupyterLab中使用的IPython魔术命令(magic command),用于捕获代码单元(cell)的输出。

%%capture 是一个单元格魔术命令(cell magic),它以两个百分号开头,并作用于整个单元格。captured_output 是你选择的变量名,用于存储捕获的输出。

2. pip install -i URL:

pip install -i 命令中的 -i 参数后面通常跟着一个URL,这个URL指向一个Python包索引(PyPI的镜像或自定义的索引服务器)。这个参数允许你指定pip从哪个源安装Python包,而不是默认的PyPI(Python Package Index)。

使用 -i 参数的好处:速度提升、稳定性增强、安全性

%%capture captured_output
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

导包

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

处理数据集

MindSpore提供基于Pipeline的数据引擎,通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。使用Mnist数据集,使用mindspore.dataset提供的数据变换进行预处理。

1. 下载数据集

# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

2. 获得数据集对象

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

3. 打印数据集中包含的数据列名,用于dataset的预处理

print(train_dataset.get_col_names())# ['image', 'label']

4. 处理数据集

MindSpore的dataset使用数据处理流水线(Data Processing Pipeline),需指定map、batch、shuffle等操作。

使用map对图像数据及标签进行变换处理,然后将处理好的数据集打包为大小为64的batch。

def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

5. 查看数据和标签的shape和datatype

 可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问

for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")break# Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32
# Shape of label: (64,) Int32# for data in test_dataset.create_dict_iterator():
#     print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
#     print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
#     break

网络构建

mindspore.nn类是构建所有网络的基类,也是网络的基本单元。导包过程中已导入。

当用户需要自定义网络时,可以继承nn.Cell类,并重写__init__方法和construct方法。

__init__包含所有网络层的定义,construct中包含数据(Tensor)的变换过程。

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)# Network<
#  (flatten): Flatten<>
#  (dense_relu_sequential): SequentialCell<
#    (0): Dense<input_channels=784, output_channels=512, has_bias=True>
#    (1): ReLU<>
#    (2): Dense<input_channels=512, output_channels=512, has_bias=True>
#    (3): ReLU<>
#    (4): Dense<input_channels=512, output_channels=10, has_bias=True>
#    >
#  >

模型训练

 在模型训练中,一个完整的训练过程(step)需要实现以下三步:

  1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
  2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
  3. 参数优化:将梯度更新到参数上。

MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:

  1. 定义正向计算函数。
  2. 使用value_and_grad通过函数变换获得梯度计算函数。
  3. 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
# Instantiate loss function and optimizer
# 使用优化器(optimizers)来更新网络中的权重,而随机梯度下降(Stochastic Gradient Descent, SGD)是其中最常见的一种优化算法。
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

评估模型性能

定义测试函数,用来评估模型的性能 

def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的loss值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。

epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")

执行结果

保存模型

模型训练完成后,将其参数进行保存

# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型

加载保存的权重分为两步:

1. 重新实例化模型对象,构造模型。
2. 加载模型参数,并将其加载至模型上。

# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)# param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

预测推理

加载后的模型可以直接用于预测推理

model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3268553.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机网络】IP分片实验

一&#xff1a;实验目的 1&#xff1a;理解IP数据报分片的工作原理。 2&#xff1a;理解IP协议报文类型和格式。 二&#xff1a;实验仪器设备及软件 硬件&#xff1a;RCMS-C服务器、网线、Windows 2019/2003操作系统的计算机等。 软件&#xff1a;记事本、WireShark、Chrom…

Pc端vue2实现横向纵向鼠标滚动布局

类似uniaapp中的scroll-view组件,可横向可竖向,样式需要自己跳整一下 横向:(鼠标按下滑动里面的元素,可滑动,滚动条和左右都可以调整) 纵向: 代码实现:主页面引入组件 <template><div><!-- 调用组件 --><!-- vertical 垂直 写宽高 例如: widt…

失业潮下,有人靠天工AI做副业年入10万?

前言 你好&#xff0c;我是咪咪酱 这篇文章总结2个AI副业项目&#xff0c;不用写代码&#xff0c;就能做的2个副业项目。 第一&#xff1a;AI生成微信表情包&#xff0c;上传到微信表情包平台等&#xff0c;坚持下去&#xff0c;会有可观的收入。 第二&#xff1a;AI生成连载…

MQ消息队列+Lua 脚本实现异步处理下单流程

具体实现和代码可参考我以前做过的笔记&#xff1a;《黑马点评》异步秒杀优化|消息队列 回顾一下下单流程&#xff1a; 用户发起请求 会先请求Nginx,Nginx反向代理到Tomcat&#xff0c;而Tomcat中的程序&#xff0c;会进行串行工作&#xff0c; 分为以下几个操作&#xff1…

KamaCoder 98. 所有可到达路径 + LC 797. All Paths From Source to Target

题目要求 给定一个有 n 个节点的有向无环图&#xff0c;节点编号从 1 到 n。请编写一个函数&#xff0c;找出并返回所有从节点 1 到节点 n 的路径。每条路径应以节点编号的列表形式表示。 输入描述 第一行包含两个整数 N&#xff0c;M&#xff0c;表示图中拥有 N 个节点&…

Apache Nifi挂接MQTT与Kafka实践

目录 1. 说明&#xff1a; 2. 方案设计&#xff1a; 2.1 资源配置&#xff1a; 2.2 交互Topics: 3. 实现步骤 3.1 Nifi 桌面 3.2 MqttToKafka 3.2.1 配置 3.2.2 测试 3.2.3 结果 3.3 KafkaToMqtt 3.3.1 配置 3.3.1 测试 3.3.1 结果 ​编辑 4. 总结&#xff…

HAL STM32 SPI/ABZ/PWM方式读取MT6816磁编码器数据

HAL STM32 SPI/ABZ/PWM方式读取MT6816磁编码器数据 &#x1f4da;MT6816相关资料&#xff08;来自商家的相关资料&#xff09;&#xff1a; 资料&#xff1a;https://pan.baidu.com/s/1CAbdLBRi2dmL4D7cFve1XA?pwd8888 提取码&#xff1a;8888&#x1f4cd;驱动代码编写&…

计科录取75人!常州大学计算机考研考情分析!

常州大学&#xff08;Changzhou University&#xff09;&#xff0c;简称“常大”&#xff0c;位于江苏省常州市&#xff0c;是江苏省人民政府与中国石油天然气集团有限公司、中国石油化工集团有限公司及中国海洋石油集团有限公司共建的省属全日制本科院校&#xff0c;为全国深…

repo中的default.xml文件project name为什么一样?

文章目录 default.xml文件介绍为什么 name 是一样的&#xff0c;path 不一样&#xff1f;总结 default.xml文件介绍 在 repo 工具的 default.xml 文件中&#xff0c;定义了多个 project 元素&#xff0c;每个元素都代表一个 Git 仓库。 XML 定义了多个不同的 project 元素&…

Vue常用指令及其生命周期

作者&#xff1a;CSDN-PleaSure乐事 欢迎大家阅读我的博客 希望大家喜欢 目录 1.常用指令 1.1 v-bind 1.2 v-model 注意事项 1.3 v-on 注意事项 1.4 v-if / v-else-if / v-else 1.5 v-show 1.6 v-for 无索引 有索引 生命周期 定义 流程 1.常用指令 Vue当中的指令…

【MySQL进阶篇】锁:全局锁、表级锁以及行级锁

一、锁的概述 锁是计算机协调多个进程或线程并发访问某一资源的机制。在数据库中除传统的计算资源&#xff08;CPU、RAM、I/O&#xff09;的争用以外&#xff0c;数据也是一种供许多用户共享的资源。如何保证数据并发访问的一致性、有效性是所有数据库必须要解决的一个问题&am…

vs2019配置MySQL记录

vs2019配置MySQL记录 一、安装MySQL 参考&#xff1a;MySQL5.5.19的安装步骤 基本上就是一路默认安装就行。 二、验证 左下角打开MySQL 输入秘密能看到如下界面&#xff0c;即表示MySQL安装成功 三、安装vs2019的MySQL驱动 这里主要参考&#xff1a;Visual Studio 201…

MySQL练习03

题目 步骤 创建数据库 create database mydb11_stu; #创建 use mydb11_stu; #使用 创建表 student表 create table student( id int(10) not null unique primary key, name varchar(20) not null, sex varchar(4),birth year, department varchar(20), address var…

AR 眼镜之-充电动画定制-实现方案

目录 &#x1f4c2; 前言 AR 眼镜系统版本 充电动画 1. &#x1f531; 技术方案 1.1 方案介绍 1.2 实现方案 关机充电动画 亮屏/锁屏充电动画 2. &#x1f4a0; 关机充电动画 2.1 关机充电动画核心处理类与路径 2.2 实现细节 步骤一&#xff1a;1&#xff09;定制 …

从零开始学习网络安全渗透测试之基础入门篇——(二)Web架构前后端分离站Docker容器站OSS存储负载均衡CDN加速反向代理WAF防护

Web架构 Web架构是指构建和管理Web应用程序的方法和模式。随着技术的发展&#xff0c;Web架构也在不断演进。当前&#xff0c;最常用的Web架构包括以下几种&#xff1a; 单页面应用&#xff08;SPA&#xff09;&#xff1a; 特点&#xff1a;所有用户界面逻辑和数据处理都包含…

VSCode切换默认终端

我的VSCode默认终端为PowerShell&#xff0c;每次新建都会自动打开PowerShell。但是我想让每次都变为cmd&#xff0c;也就是Command Prompt 更改默认终端的操作方法如下&#xff1a; 键盘调出命令面板&#xff08;CtrlShiftP&#xff09;中,输入Terminal: Select Default Prof…

【记忆化搜索】【超详细】力扣3186. 施咒的最大总伤害

一个魔法师有许多不同的咒语。 给你一个数组 power &#xff0c;其中每个元素表示一个咒语的伤害值&#xff0c;可能会有多个咒语有相同的伤害值。 已知魔法师使用伤害值为 power[i] 的咒语时&#xff0c;他们就 不能 使用伤害为 power[i] - 2 &#xff0c;power[i] - 1 &…

记录安装android studio踩的坑 win7系统

最近在一台新电脑上安装android studio,报了很多错误&#xff0c;也是费了大劲才解决&#xff0c;发出来大家一起避免一些问题&#xff0c;找到解决方法。 安装时一定要先安装jdk&#xff0c;cmd命令行用java -version查当前的版本&#xff0c;没有的话&#xff0c;先安装jdk,g…

地形材质制作(能使地面湿润)

如图&#xff0c;创建一个材质并写以下逻辑 Landscape Layer Blend节点能使在地形模式绘制中有三个选择&#xff0c;根据以上逻辑&#xff0c;Red是原材质,Green是绿色材质也就是草&#xff0c;Blue为水&#xff08;这个我认为比较重要&#xff09; Blue的颜色最好为这个 这个节…

QEMU源码全解析 —— CPU虚拟化(11)

接前一篇文章: 本文内容参考: 《趣谈Linux操作系统》 —— 刘超,极客时间 《QEMU/KVM》源码解析与应用 —— 李强,机械工业出版社 《深度探索Linux系统虚拟化原理与实现》—— 王柏生 谢广军, 机械工业出版社 特此致谢! 前边几回又再次讲了一下VMX,本回开始讲解VCPU…