《PyTorch深度学习实践》第八讲加载数据集

一、

1、DataSet 是抽象类,不能实例化对象,主要是用于构造我们的数据集

2、DataLoader 需要获取DataSet提供的索引[i]和len;用来帮助我们加载数据,比如说做shuffle(提高数据集的随机性),batch_size,能拿出Mini-Batch进行训练。它帮我们自动完成这些工作。DataLoader可实例化对象。DataLoader is a class to help us loading data in Pytorch.

3、__getitem__目的是为支持下标(索引)操作
 

二、

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(xy[:, :-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多线程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# training cycle forward, backward, update
if __name__ == '__main__':for epoch in range(100):for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batchinputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

1、需要mini_batch 就需要import DataSet和DataLoader

2、继承DataSet的类需要重写init,getitem,len魔法函数。分别是为了加载数据集,获取数据索引,获取数据总量。

3、DataLoader对数据集先打乱(shuffle),然后划分成mini_batch。

4、len函数的返回值 除以 batch_size 的结果就是每一轮epoch中需要迭代的次数。

5、inputs, labels = data中的inputs的shape是[32,8],labels 的shape是[32,1]。也就是说mini_batch在这个地方体现的

6、diabetes.csv数据集老师给了下载地址,该数据集需和源代码放在同一个文件夹内。

问题:loss没有收敛

网友解决:

做了两个实验:(1)输出每批次的loss,不收敛,loss在0.6上下浮动(2)每个epoch都不分批,把所有样本都输入,收敛,最后结果在0.6附近。所以猜测:小样本之间的loss差距相对于0.6而言有点大,所以看着像是没收敛,实际上从总loss来看已经收敛了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2821736.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Entry First Day 入职恩孚第一天

入职第一天,电脑还没配置好就去了工厂。 熟悉了一下设备,切了几个小玩意, hello world 一下。 了解了串行端口的Nodejs的库 https://github.com/serialport/node-serialport,以后要用这个东西和硬件通讯,安装&#…

【C++精简版回顾】15.继承派生

1.继承派生的区别 继承:子继父业,就是子类完全继承父类的全部内容 派生:子类在父类的基础上发展 2.继承方式 1.public继承为原样继承 2.protected继承会把public继承改为protect继承 3.private继承会把public,protected继承改为pr…

指针与malloc动态内存申请,堆和栈的差异

定义了两个函数print_stack()和print_malloc(),分别演示了两种不同的内存分配方式:栈内存和堆内存。然后在main()函数中调用这两个函数,并将它们返回的指针打印出来。 由于print_stack()中的数组c是在栈上分配的,当函数返回后&…

【论文笔记】An Effective Adversarial Attack on Person Re-Identification ...

原文标题(文章标题处有字数限制): 《An Effective Adversarial Attack on Person Re-Identification in Video Surveillance via Dispersion Reduction》 Abstract 通过减少神经网络内部特征图的分散性攻击reid模型。 erbloo/Dispersion_r…

【论文阅读-基于VilLBERT方法的导航】Vison-Language Navigation 视觉语言导航(2)

文章目录 1. 【2023ICCV】Learning Vision-and-Language Navigation from YouTube Videos摘要和结论引言Building VLN Dataset from YouTube Videos模型框架实验 2. 【2021ICCV】Airbert: In-domain Pretraining for Vision-and-Language Navigation摘要和结论引言BnB DatasetA…

爬取某牙视频

爬取页面链接:游戏视频_游戏攻略_虎牙视频 爬取步骤:点进去一个视频播放,查看media看有没有视频,发现没有。在xhr中发现有许多ts文件,但这种不是很长的视频一般都有直接的播放链接,所以目标还是找直接的链…

逻辑漏洞(pikachu)

#水平,垂直越权,未授权访问 通过个更换某个id之类的身份标识,从而使A账号获取(修改、删除)B账号数据 使用低权限身份的账号,发送高权限账号才能有的请求,获得其高权限操作 通过删除请求中的认…

爱普生的SG2016系列高频,低相位抖动spxo样品

精工爱普生公司(TSE: 6724,“爱普生”)已经开始发货样品的新系列简单封装晶体振荡器(SPXO)与差分输出1。该系列包括SG2016EGN、SG2016EHN、SG2016VGN和SG2016VHN。它们在基本模式下都具有低相位抖动,并且采用尺寸为2.0 x 1.6 mm的小封装,高度…

【兔子机器人】五连杆运动学解算与VMC(virtual model control)

VMC (virtual model control,虚拟模型控制) 是一种直觉控制方式,其关键是在每个需要控制的自由度上构造恰当的虚拟构件以产生合适的虚拟力。虚拟力不是实际执行机构的作用力或力矩,而是通过执行机构的作用经过机构转换而成。对于一些控制问题…

云游戏:畅享3A游戏大作的全新时代

在科技飞速发展的今天,云游戏以其独特的魅力正深刻改变着游戏的玩法方式。无需昂贵硬件,突破设备限制畅玩3A大作,云游戏为玩家们带来了前所未有的游戏乐趣。本文将深入探讨云游戏的核心优势,为你呈现畅玩游戏的全新时代。 1. 无硬…

Go 互斥锁的实现原理?

Go sync包提供了两种锁类型:互斥锁sync.Mutex 和 读写互斥锁sync.RWMutex,都属于悲观锁。 概念 Mutex是互斥锁,当一个 goroutine 获得了锁后,其他 goroutine 不能获取锁(只能存在一个写者或读者,不能同时…

JVM内存回收算法

1.1 引用计数法 每个对象创建的时候,会分配一个引用计数器,当这个对象被引用的时候计数器就加1,当不被引用或者引用失效的时候计数器就会减1。任何时候,对象的引用计数器值为0就说明这个对象不被使用了,就认为是“垃圾…

从8.8到9.9,涨价的库迪还能守住牌局吗?

作者 | 辰纹 来源 | 洞见新研社 历经超半年的9.9元活动后,瑞幸不仅牢牢守稳盈利态势,还一举创造了新的神话——中国地区年收入首超星巴克。 根据瑞幸咖啡发布的截至12月31日的2023年第四季度及全年财报。第四季度,瑞幸咖啡净营收为70.6亿元…

【VTKExamples::PolyData】第四十一期 PointLocator

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 前言 本文分享VTK样例PointLocator,并解析接口vtkPointLocator,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 1. PointLocator …

ubuntu16.04安装Mysql8.0.25

更换数据源(阿里云) sudo cp /etc/apt/sources.list /etc/apt/sources.list.bak sudo vi /etc/apt/sources.listdeb http://mirrors.aliyun.com/ubuntu/ xenial main multiverse restricted universe deb http://mirrors.aliyun.com/ubuntu/ xenial-backports main multiverse…

第四十三天| 1049. 最后一块石头的重量 II、494. 目标和、474.一和零

01背包问题 Leetcode 1049. 最后一块石头的重量 II 题目链接:1049 最后一块石头的重量 II 题干:有一堆石头,用整数数组 stones 表示。其中 stones[i] 表示第 i 块石头的重量。 每一回合,从中选出任意两块石头,然后将…

Harbor高可用(haproxy和keepalived)

Harbor高可用(haproxy和keepalived) 文章目录 Harbor高可用(haproxy和keepalived)1.Harbor高可用集群部署架构1.1 主机初始化1.1.1 设置网卡名和ip地址1.1.2 设置主机名1.1.3 配置镜像源1.1.4 关闭防火墙1.1.5 禁用SELinux1.1.6 设…

Typora快捷键设置详细教程(内附每个步骤详细截图)

😎 作者介绍:我是程序员洲洲,一个热爱写作的非著名程序员。CSDN全栈优质领域创作者、华为云博客社区云享专家、阿里云博客社区专家博主、前后端开发、人工智能研究生。公粽号:程序员洲洲。 🎈 本文专栏:本文…

探秘Python的Pipeline魔法

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站AI学习网站。 目录 前言 什么是Pipeline? Pipeline的基本用法 Pipeline的高级用法 1. 动态调参 2. 并行处理 3. 多输出 …

第四十七回 一丈青单捉王矮虎 宋公明二打祝家庄-强大而灵活的python装饰器

四面全是埋伏,宋江和众人一直绕圈跑不出去。正在慌乱之时,石秀及时赶到,教大家碰到白杨树就转弯走。走了一段时间,发现围的人越来越多,原来祝家庄以灯笼指挥号令。花荣一箭射下来红灯龙,伏兵自己就乱起来了…