【深度学习实战(33)】训练之model.train()和model.eval()

一、model.train(),model.eval()作用?

model.train() 和 model.eval() 是 PyTorch 中的两个方法,用于设置模型的训练模式和评估模式。

model.train() 方法将模型设置为训练模式。在训练模式下,模型会启用 dropout 和 batch normalization 等正则化方法,并且可以计算梯度以进行参数更新,同时还可以追踪梯度计算的图。训练时,均值、方差分别是该批次内数据相应维度的均值与方差

model.eval() 方法将模型设置为评估模式。在评估模式下,模型会禁用 dropout 和 batch normalization 等正则化方法,这样可以保证每次评估的结果是确定的。评估模式下的模型通常用于模型的测试、验证或推理阶段。推理时,均值、方差是基于所有批次的期望计算所得

区分训练模式和评估模式的目的在于保证模型在不同阶段的行为一致性。例如,在训练模式下,模型需要计算并追踪梯度以进行反向传播和参数更新;而在评估模式下,模型不需要计算梯度,只需要给出确定的预测结果。

二、model.train(),model.eval()对dropout产生的影响

使用model.train():有神经元被置零,且比例符合nn.Dropout(0.5)中的0.5设定

import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
model.train()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述
使用model.eval():没有神经元置零,nn.Dropout(0.5)被关闭

import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
#model.train()
model.eval()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

不使用model.train()和model.eval():有神经元被置零,但是比例非常随机,不符合nn.Dropout(0.5)中的0.5设定
import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
#model.train()
#model.eval()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

三、model.train(),model.eval()对batch normalization产生的影响

使用model.eval():bn中的均值,方差,不发生改变

# 1.导入所需的库:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# 2.定义数据集的转换方法。MNIST数据集是由28x28像素的手写数字组成的图像,将其转换为torch张量并进行标准化处理:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 3.下载MNIST数据集并进行转换:
trainset = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False,download=True, transform=transform)# 4.创建数据加载器:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False, num_workers=0)# 5.现在你可以使用trainloader和testloader来获取训练集和测试集的批次数据了。例如,可以使用迭代器遍历数据集中的批次:
#dataiter = iter(trainloader)
#images, labels = dataiter.next()# 上述代码将返回一个批次的图像和对应的标签。可以使用images和labels来进行模型的训练和评估。
# 这就是使用torch库自带的MNIST数据集的基本流程。根据需要,你还可以添加其他的数据处理和增强步骤。# 定义模型
class Model(nn.Module):def __init__(self, hidden_num=32, out_num=10):super().__init__()self.fc1 = nn.Linear(28*28, hidden_num)self.bn  = nn.BatchNorm1d(hidden_num)self.fc2 = nn.Linear(hidden_num, out_num)self.softmax = nn.Softmax()def forward(self, inputs, **kwargs):x = inputs.flatten(1)x = self.fc1(x)print("========= bn之前存的数据: =========")print(self.bn.running_mean, self.bn.running_var)print()print("========= 当前 Batch 的数据: =========")x_mean = torch.mean(x,0)x_variance = torch.mean((x - x_mean)*(x - x_mean),0)print(x_mean, x_variance)print()print("========= torch官方计算之后的bn新数据: =========")x = self.bn(x)print(self.bn.running_mean, self.bn.running_var)print()# x = self.dropout(x)x = self.fc2(x)x = self.softmax(x)return xtorch.manual_seed(1)
model = Model()
#model.train()
model.eval()
for img, label in trainloader:label = nn.functional.one_hot(label.flatten(), 10)out = model(img)break

在这里插入图片描述
使用model.train():bn中的均值,方差,通过滑动平均地方式发生改变,

torch.manual_seed(1)
model = Model()
model.train()
#model.eval()
for img, label in trainloader:label = nn.functional.one_hot(label.flatten(), 10)out = model(img)break

在这里插入图片描述
不使用model.train()和model.eval():默认bn中的均值,方差,通过滑动平均地方式发生改变,
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3015859.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

python爬虫学习------scrapy第三部分(第三十一天)

🎈🎈作者主页: 喔的嘛呀🎈🎈 🎈🎈所属专栏:python爬虫学习🎈🎈 ✨✨谢谢大家捧场,祝屏幕前的小伙伴们每天都有好运相伴左右,一定要天天…

ttkbootstrap界面美化系列之PanedWindow(七)

在界面设计中经常用PanedWindow控件来对整个界面进行切割布局,让整个界面看上去有层次感,不至于说杂乱无章。在我之前的文章中有对tkinter的该控件做了详细的介绍,链接如下基于Tkinter的PanedWindow组件进行窗口布局-CSDN博客 本文主要是介绍…

MapReduce的Shuffle过程

Shuffle是指从 Map 产生输出开始,包括系统执行排序以及传送Map输出到Reduce作为输入的过程. Shuffle 阶段可以分为 Map 端的 Shuffle 阶段和 Reduce 端的 Shuffle 阶段. Shuffle 阶段的工作过程,如图所示: Map 端的 Shuffle 阶段 1)每个输入分片会让一个 Map 任务…

Linux学习之路 -- 文件 -- 文件描述符

前面介绍了与文件相关的各种操作&#xff0c;其中的各个接口都离不开一个整数&#xff0c;那就是文件描述符&#xff0c;本文将介绍文件描述符的一些相关知识。 目录 <1>现象 <2>原理 文件fd的分配规则和利用规则实现重定向 <1>现象 我们可以先通过prin…

04_SpringCloud

文章目录 单体架构与微服务架构的介绍单体架构微服务架构 微服务的实现服务之间的调用服务注册中心Eureka 注册中心Eureka的自我保护机制Nacos注册中心 单体架构与微服务架构的介绍 单体架构 单体架构 所有的代码最终打包成一个文件(jar包)&#xff0c;整个系统的所有功能单元…

爆赞好文之java反序列化之CB超详细易懂分析

java反序列化之CB超详细易懂分析 CB1环境搭建前言分析PropertyUtilsBeanComparatorPriorityQueue CB2环境搭建前言exp CB1 环境搭建 pom.xml <dependencies><dependency><groupId>commons-beanutils</groupId><artifactId>commons-beanutils&l…

生成gitee公钥

1、打开设置 2、设置SSH公钥 3、生成公钥 4、复制终端输出的公钥&#xff0c;放到这里&#xff0c;标题随便取。 5、测试 ssh -T gitgitee.com 最后用这个测试

IT 项目管理介绍和资料汇总

IT项目管理到底是什么&#xff1f;是对组织承担的任何信息技术项目的成功监督。IT项目经理负责规划、预算、执行、领导、故障排除和维护这些项目。IT项目经理可能会做的事情包括&#xff1a; 1、硬件安装 2、软件、网站和应用程序开发 3、网络和云计算解决方案的升级和/或推出…

软考中级-软件设计师(八)算法设计与分析 考点最精简

一、算法设计与分析的基本概念 1.1算法 算法&#xff08;Algorithm&#xff09;是对特定问题求解步骤的一种描述&#xff0c;有5个重要特性&#xff1a; 有穷性&#xff1a;一个算法必须总是在执行又穷步骤后结束&#xff0c;且每一步都可在又穷时间内完成 确定性算法中每一…

KAN: Kolmogorov–Arnold Networks

KAN: Kolmogorov–Arnold Networks 论文链接&#xff1a;https://arxiv.org/abs/2404.19756 代码链接&#xff1a;https://github.com/KindXiaoming/pyKAN 项目链接&#xff1a;https://kindxiaoming.github.io/pyKAN/intro.html Abstract 受Kolmogorov-Arnold表示定理的启…

访问网络附加存储:nfs

文章目录 访问网络附加存储一、网络附加存储1.1、存储类型1.3、通过NFS挂载NAS1.4、NFS挂载过程服务端客户端 二、实验&#xff1a;搭建NFS服务端及挂载到nfs客户端服务端客户端测试命令合集服务端客户端 访问网络附加存储 一、网络附加存储 1.1、存储类型 DAS&#xff1a;Di…

mysql的数据结构及索引使用情形

先来说下数据的一般存储方式&#xff1a;内存(适合小数据量)、磁盘(大数据量)。 磁盘的运转方式&#xff1a;速度 旋转&#xff0c;磁盘页的概念&#xff1a;每一页大概16KB。 1、存储结构 哈希 是通过hash函数计算出一个hash值的&#xff0c;哈希的优点就是查找的时间复杂度…

图片编辑工具-Gimp

一、前言 GIMP&#xff08;GNU Image Manipulation Program&#xff09;是一款免费开源的图像编辑软件&#xff0c;具有功能强大和跨平台的特性。 GIMP作为一个图像编辑器&#xff0c;它提供了广泛的图像处理功能&#xff0c;包括但不限于照片修饰、图像合成以及创建艺术作品…

SpringMVC响应数据

三、SpringMVC响应数据 3.1 handler方法分析 理解handler方法的作用和组成&#xff1a; /*** TODO: 一个controller的方法是控制层的一个处理器,我们称为handler* TODO: handler需要使用RequestMapping/GetMapping系列,声明路径,在HandlerMapping中注册,供DS查找!* TODO: ha…

【notepad++】使用

1 notepad 下载路径 https://notepad-plus.en.softonic.com/download 2 设置护眼模式 . 设置——语言格式设置——前景色——黑色 . 背景色——RGB &#xff1a;199 237 204 . 勾选“使用全局背景色”、“使用全局前景色” . 保存并关闭

Python专题:二、Python小游戏,体验Python的魅力

希望先通过一个小的游戏让大家先对Python感兴趣&#xff0c;兴趣是最好的老师。 小游戏的运行结果&#xff1a; 1、在sublime编辑器里面写如下代码&#xff1a; import randomnum random.randint(1, 100) # 获得一个随机数 is_done False # 是否猜中的标记 count 0 # 玩…

软件设计师-应用技术-数据结构及算法题4

考题形式&#xff1a; 第一题&#xff1a;代码填空 4-5空 8-10第二题&#xff1a;时间复杂度 / 代码策略第三题&#xff1a;拓展&#xff0c;跟一组数据&#xff0c;把数据带入代码中&#xff0c;求解 基础知识及技巧&#xff1a; 1. 分治法&#xff1a; 基础知识&#xff1…

美易官方:英伟达业绩将难以撑起股价?

美股市场似乎总是对各大公司的业绩表现抱有极大的期待&#xff0c;就像一个永远填不饱的“巨胃”。在这样的市场环境下&#xff0c;即使是业绩骄人的公司也可能难以支撑其股价。英伟达&#xff0c;这家在图形处理单元&#xff08;GPU&#xff09;领域享有盛誉的公司&#xff0c…

语音识别--kNN语音指令识别

⚠申明&#xff1a; 未经许可&#xff0c;禁止以任何形式转载&#xff0c;若要引用&#xff0c;请标注链接地址。 全文共计3077字&#xff0c;阅读大概需要3分钟 &#x1f308;更多学习内容&#xff0c; 欢迎&#x1f44f;关注&#x1f440;【文末】我的个人微信公众号&#xf…

英语学习笔记5——Nice to meet you.

Nice to meet you. 很高兴见到你。 词汇 Vocabulary Mr. 先生 用法&#xff1a;自己全名 / 姓 例如&#xff1a;Mr. Zhang Mingdong 或 Mr. Zhang&#xff0c;绝对不能是 Mr. Mingdong&#xff01; Miss 女士&#xff0c;小姐 未婚 用法&#xff1a;自己全名 / 姓 例如&#…