pytorch-trainvaltest划分

目录

  • 1. 上一节回顾
  • 2. 数据集划分
  • 3. 完整代码

1. 上一节回顾

下列图中三种曲线分别代表了欠拟合、好的拟合和过拟合
在这里插入图片描述
下图为过拟合曲线,那么如何来检测过拟合呢?将数据集划分为train和val(validation)val是用来测试训练过程是否过拟合的。
在这里插入图片描述

2. 数据集划分

数据集一般划分为train、val(当划分为2个数据集时,val又被称为test)、test
以MINIST数据集为例,下图中将数据集化为为train和test(val)
在这里插入图片描述
下图中每train一个epoch就进行依次test看是否发生了过拟合,并保存当时的checkpoint,训练完成后选择性能最好的checkpoints即可。
在这里插入图片描述
下图中在标记点之后train的loss变换不明显,而test loss却升高了,这就是发生了overfitting过拟合。
在这里插入图片描述
如下图:假如又70k数据train50k,val10k,test10k,那么为什么要又test数据集呢?作用是什么呢?test数据集是用来验证模型的性能的,不能用于train和val否则会造成数据污染,也可称为作弊。
在这里插入图片描述
pytorch只能划分train和test,即通过train_db = datasets.MNIST('../data', train=True, download=True, 中的train=true即为train数据集,否则为test数据集 ,那么剩下的train我们要人为划分为train和val。
如下图所示:通过train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000]) 实现把train划分为train50k,val10k
在这里插入图片描述
总结:train数据用来训练,val数据用来检测训练是否过拟合的,test数据集是用来验证模型性能的

3. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_db = datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
train_loader = torch.utils.data.DataLoader(train_db,batch_size=batch_size, shuffle=True)test_db = datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(test_db,batch_size=batch_size, shuffle=True)print('train:', len(train_db), 'test:', len(test_db))
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))
train_loader = torch.utils.data.DataLoader(train_db,batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_db,batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in val_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(val_loader.dataset)print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(val_loader.dataset),100. * correct / len(val_loader.dataset)))test_loss = 0
correct = 0
for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2980512.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

CSS 04

去掉 li 前面的 项目符号(小圆点) 语法 list-style: none;圆角边框 border-radius 属性用于设置元素的外边框圆角。 语法 border-radius:length;参数值可以为数值或百分比的形式如果是正方形,想要设置为一个圆,把数值修改为高度或者宽度的一半即可&a…

Opencv_11_通道的分离与合并

void ColorInvert::channels_demo(Mat& image) { std::vector<Mat> mv; split(image, mv); imshow("蓝色", mv[0]); imshow("绿色", mv[1]); imshow("红色", mv[2]); Mat dst; mv[0] 0; merge(mv, dst);…

【Camera KMD ISP SubSystem笔记】CRM V4L2驱动模型

1. CRM为主设备 /dev/video0&#xff0c;先创建 v4l2_device 设备&#xff0c;再创建 video_device 设备&#xff0c;最后创建 media_device 设备/dev/media0 v4l2_device的mdev指向media_device&#xff0c;v4l2_device的entity链接到media_device的entities上&#xff08…

WEB服务的配置与使用 Apache HTTPD

服务端&#xff1a;服务器将发送由状态代码和可选的响应正文组成的 响应 。状态代码指示请求是否成功&#xff0c;如果不成功&#xff0c;则指示存在哪种错误情况。这告诉客户端应该如何处理响应。较为流星的web服务器程序有&#xff1a; Apache HTTP Server 、 Nginx 客户端&a…

【debug记录】有gpu,但是 pytorch仍显示 cpu【原来是新电脑没安装cuda】

原来是新电脑没安装cuda&#xff0c;以为安装了pytorch包就可以了。 检查过程&#xff1a; nvcc 不是内部或外部命令&#xff0c;也不是可运行的程序, 说明没有安装cuda。 查看电脑显卡最高支持cuda版本&#xff1a;nvidia-smi 安装cuda&#xff0c;选择版本&#xff1a;ht…

Android Studio 报错:AVD Pixel_3a_API_30_x86 is already running

在我的Android Studio和虚拟机运行时&#xff0c;我的电脑不小心关机了&#xff0c;在启动后再次打开Android Studio并运行虚拟机时发现报错。 Error while waiting for device: AVD Pixel_3a_API_30_x86 is already running. If that is not the case, delete the files at C…

系统安全与应用(1)

目录 1、账号安全管理 &#xff08;1&#xff09;禁止程序用户登录 &#xff08;2&#xff09;锁定禁用长期不使用的用户 &#xff08;3&#xff09;删除无用的账号 &#xff08;4&#xff09;禁止账号和密码的修改 2、密码安全管理 设置密码有效期 1&#xff09;针对已…

《ElementPlus 与 ElementUI 差异集合》el-select 差异点,如:高、宽、body插入等

宽度 Element UI 父元素不限制宽度时&#xff0c;默认有个宽度 207px&#xff1b; 父元素有固定宽度时&#xff0c;以父元素宽度为准&#xff1b; Element Plus 父元素不限制宽度时&#xff0c;默认100%&#xff1b; 父元素有固定宽度时&#xff0c;以父元素宽度为准&#x…

【模电】常见经典运放电路(持续更新)

反相 反相输入比例电路 仿真文件 链接&#xff1a;https://pan.baidu.com/s/1nft1B3mgNpoPfgWo6pFE1g?pwdfpd2 提取码&#xff1a;fpd2 同相 同相输入比例电路 仿真文件 链接&#xff1a;https://pan.baidu.com/s/151OzVgJ2M1iLJ9GCH3xp_A?pwdelec 提取码&#xff1a;…

ROS1快速入门学习笔记 - 04创建工作环境与功能包

一、定义 工作空间(workspace)是一个存放工程开发相关文件的文件夹。 src:代码空间&#xff08;Source Space&#xff09;build: 编辑空间&#xff08;Build Space&#xff09;devel:开发空间&#xff08;Development Space&#xff09;install:安装空间&#xff08;Install …

OpenHarmony实战开发-页面布局检查器ArkUI Inspector使用指导

DevEco Studio内置ArkUI Inspector工具&#xff0c;开发者可以使用ArkUI Inspector&#xff0c;在DevEco Studio上查看应用在真机上的UI显示效果。利用ArkUI Inspector工具&#xff0c;开发者可以快速定位布局问题或其他UI相关问题&#xff0c;同时也可以观察和了解不同组件之间…

TiDB 6.x 新特性解读 | Collation 规则

对数据库而言&#xff0c;合适的字符集和 collation 规则能够大大提升使用者运维和分析的效率。TiDB 从 v4.0 开始支持新 collation 规则&#xff0c;并于 TiDB 6.0 版本进行了更新。本文将深入解读 Collation 规则在 TiDB 6.0 中的变更和应用。 引 这里的“引”&#xff0c;…

【服务器部署篇】Linux下Ansible安装和配置

作者介绍&#xff1a;本人笔名姑苏老陈&#xff0c;从事JAVA开发工作十多年了&#xff0c;带过刚毕业的实习生&#xff0c;也带过技术团队。最近有个朋友的表弟&#xff0c;马上要大学毕业了&#xff0c;想从事JAVA开发工作&#xff0c;但不知道从何处入手。于是&#xff0c;产…

碳课堂|什么是碳市场?如何进行碳交易?

近年来&#xff0c;随着全球变暖问题日益受到重视&#xff0c;碳达峰、碳中和成为国际社会共识&#xff0c;为更好地减缓和适应气候变化&#xff0c;同时降低碳关税风险&#xff0c;以“二氧化碳的排放权利”为商品的碳交易和碳市场应时而生。 一、什么是碳交易、碳市场 各国…

python爬虫 - 爬取html中的script数据(36kr.com新闻信息)

文章目录 1. 分析页面内容数据格式2. 使用re.findall方法&#xff0c;爬取新闻3. 使用re.search 方法&#xff0c;爬取新闻 1. 分析页面内容数据格式 打开 https://36kr.com/ 按F12&#xff08;或 在网页上右键 --> 检查&#xff08;Inspect&#xff09;&#xff09; 找…

17.Nacos与Eureka区别

Nacos会将服务的提供者分为临时实例和非临时实例。默认为临时实例。 临时实例跟eureka一样&#xff0c;会向注册中心报告心跳监测自己是否还活着。如果不正常了nacos会剔除临时实例。&#xff08;捡来的孩子&#xff09; 非临时实例&#xff0c;nacos会主动询问服务提供者是否…

Unity进阶之ScriptableObject

目录 ScriptableObject 概述ScriptableObject数据文件的创建数据文件的使用非持久数据让其真正意义上的持久ScriptableObject的应用配置数据复用数据数据带来的多态行为单例模式化的获取数据 ScriptableObject 概述 ScriptableObject是什么 ScriptableObject是Unity提供的一个…

ElasticSearch笔记一

随着这个业务的发展&#xff0c;我们的数据量越来越庞大。那么传统的这种mysql的数据库就渐渐的难以满足我们复杂的业务需求了。 所以在微服务架构下一般都会用到一种分布式搜索的技术。那么今天呢我们就会带着大家去学习分布搜索当中最流行的一种ElasticSearch&#xff0c;Ela…

锂电池3.7V-4.2V降3.3V2.8V同步降压WT6015

锂电池3.7V-4.2V降3.3V2.8V同步降压WT6015 WT6015 是一款高效单片同步步降稳压器&#xff0c;采用恒定频率和电流模式架构。该设备提供可调节版本&#xff0c;适应不同的应用需求。在无负载条件下&#xff0c;其电源电流仅为40微安&#xff0c;而在关断状态下&#xff0c;电流…

HTB Runner

Runner User Nmap ──(root㉿kali)-[/home/…/machine/SeasonV/linux/Runner] └─# nmap -A runner.htb -T 4 Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-04-22 23:07 EDT Stats: 0:00:01 elapsed; 0 hosts completed (1 up), 1 undergoing SYN Stealth Sca…