Dataloader数据集的制作

数据集Dataloader制作

如何自定义数据集:

  • 1.数据和标签的目录结构先搞定(得知道到哪读数据)
  • 2.写好读取数据和标签路径的函数(根据自己数据集情况来写)
  • 3.完成单个数据与标签读取函数(给dataloader举一个例子)

咱们以花朵数据集为例:

  • 原来数据集都是以文件夹为类别ID,现在咱们换一个套路,用txt文件指定数据路径与标签(实际情况基本都这样)
  • 这回咱们的任务就是在txt文件中获取图像路径与标签,然后把他们交给dataloader
  • 核心代码非常简单,按照对应格式传递需要的数据和标签就可以啦
    在这里插入图片描述
import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

先来分细节整明白咱一会要干啥!

任务1:读取txt文件中的路径和标签

  • 第一个小任务,从标注文件中读取数据和标签
  • 至于你准备存成什么格式,都可以的,一会能取出来东西就行
def load_annotations(ann_file):data_infos = []with open(ann_file) as f:samples = [x.strip().split(' ')for x in f.readlines]
def load_annotations(ann_file):data_infos = {}with open(ann_file) as f:samples = [x.strip().split(' ') for x in f.readlines()]for filename, gt_label in samples:data_infos[filename] = np.array(gt_label, dtype=np.int64)return data_infos

在这里插入图片描述

任务2:分别把数据和标签都存在list里

  • 不是我非让你存list里,因为dataloader到时候会在这里取数据
  • 按照人家要求来,不要耍个性,让整list咱就给人家整
img_label = load_annotations('./flower_data/train.txt')
image_name = list(img_label.keys())
label = list(img_label.values())

任务3:图像数据路径得完整

  • 因为一会咱得用这个路径去读数据,所以路径得加上前缀
  • 以后大家任务不同,数据不同,怎么加你看着来就行,反正得能读到图像
data_dir = './flower_data/'
train_dir = data_dir + '/train_filelist'
valid_dir = data_dir + '/val_filelist'
image_path = [os.path.join(train_dir,img) for img in image_name]
image_path

任务4:把上面那几个事得写在一起

  • 1.注意要使用from torch.utils.data import Dataset, DataLoader
  • 2.类名定义class FlowerDataset(Dataset),其中FlowerDataset可以改成自己的名字
  • 3.def init(self, root_dir, ann_file, transform=None):咱们要根据自己任务重写
  • 4.def getitem(self, idx):根据自己任务,返回图像数据和标签数据
from torch.utils.data import Dataset, DataLoader
class FlowerDataset(Dataset):def __init__(self, root_dir, ann_file, transform=None):self.ann_file = ann_fileself.root_dir = root_dirself.img_label = self.load_annotations()self.img = [os.path.join(self.root_dir,img) for img in list(self.img_label.keys())]self.label = [label for label in list(self.img_label.values())]self.transform = transformdef __len__(self):return len(self.img)def __getitem__(self, idx):image = Image.open(self.img[idx])label = self.label[idx]if self.transform:image = self.transform(image)label = torch.from_numpy(np.array(label))return image, labeldef load_annotations(self):data_infos = {}with open(self.ann_file) as f:samples = [x.strip().split(' ') for x in f.readlines()]for filename, gt_label in samples:data_infos[filename] = np.array(gt_label, dtype=np.int64)return data_infos

任务5:数据预处理(transform)

  • 1.预处理的事都在上面的__getitem__中完成,需要对图像和标签咋咋地的,要整啥事,都在上面整
  • 2.返回的数据和标签就是建模时模型的输入和损失函数中标签的输入,一定整明白自己模型要啥
  • 3.预处理这个事是你定的,不同的数据需要的方法也不一样,下面给出的是比较通用的方法
data_transforms = {'train': transforms.Compose([transforms.Resize(64),transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选transforms.CenterCrop(64),#从中心开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差]),'valid': transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

任务6:根据写好的class FlowerDataset(Dataset):来实例化咱们的dataloader

  • 1.构建数据集:分别创建训练和验证用的数据集(如果需要测试集也一样的方法)
  • 2.用Torch给的DataLoader方法来实例化(batch啥的自己定,根据你的显存来选合适的)
  • 3.打印看看数据里面是不是有东西了
train_dataset = FlowerDataset(root_dir=train_dir, ann_file = './flower_data/train.txt', transform=data_transforms['train'])
val_dataset = FlowerDataset(root_dir=valid_dir, ann_file = './flower_data/val.txt', transform=data_transforms['valid'])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

任务7:用之前先试试,整个数据和标签对应下,看看对不对

  • 1.别着急往模型里传,对不对都不知道呢
  • 2.用这个方法:iter(train_loader).next()来试试,得到的数据和标签是啥
  • 3.看不出来就把图画出来,标签打印出来,确保自己整的数据集没啥问题
image, label = iter(train_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))
image, label = iter(val_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))

任务8:咋用就是你来定了,把模型啥的整好往里面传吧

  • 下面这些事之前都唠过了,按照自己习惯的方法整就得了
dataloaders = {'train':train_loader,'valid':val_loader}
model_name = 'resnet'  #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
#是否用人家训练好的特征来做
feature_extract = True 
# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available.  Training on CPU ...')
else:print('CUDA is available!  Training on GPU ...')device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = models.resnet18()
model_ft
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
input_size = 64
model_ft
# 优化器设置
optimizer_ft = optim.Adam(model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
criterion = nn.CrossEntropyLoss()
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, filename='best.pth'):since = time.time()best_acc = 0model.to(device)val_acc_history = []train_acc_history = []train_losses = []valid_losses = []LRs = [optimizer.param_groups[0]['lr']]best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 训练和验证for phase in ['train', 'valid']:if phase == 'train':model.train()  # 训练else:model.eval()   # 验证running_loss = 0.0running_corrects = 0# 把数据都取个遍for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 清零optimizer.zero_grad()# 只有训练的时候计算和更新梯度with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)loss = criterion(outputs, labels)_, preds = torch.max(outputs, 1)#print(loss)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算损失running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - sinceprint('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 得到最好那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重'best_acc': best_acc,'optimizer' : optimizer.state_dict(),#优化器的状态信息}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)scheduler.step(epoch_loss)#学习率衰减if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 训练完后用最好的一次当做模型最终的结果,等着一会测试model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, filename='best.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/1381961.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

RabbitMQ 消息队列(Spring boot AMQP)

文章目录 🍰有几个原因可以解释为什么要选择 RabbitMQ:🥩mq之间的对比🌽RabbitMQ vs Apache Kafka🌽RabbitMQ vs ActiveMQ🌽RabbitMQ vs RocketMQ🌽RabbitMQ vs Redis 🥩linux docke…

Android App消息推送 实现原理

https://www.jianshu.com/p/b61a49e0279f 1.消息推送的实质 实际上,是当服务器有新消息需推送给用户时,先发送给应用App,应用App再发送给用户 2. 作用产品角度:功能需要,如:资讯类产品的新闻推送、工具类…

App消息推送 实现原理

1.消息推送的实质 实际上,是当服务器有新消息需推送给用户时,先发送给应用App,应用App再发送给用户 2. 作用 产品角度:功能需要,如:资讯类产品的新闻推送、工具类产品的公告推送等等 运营角度:活…

浏览器及app消息推送

消息推送 什么是消息推送PC端的实现方法1:Notification方法2:pushjs APP端实现打包设置 什么是消息推送 消息推送可以存在于浏览器端,也存在APP端。浏览器的推送,会在电脑通知中显示,app中显示在通知栏 PC端的实现 方法1:Notif…

IOS推送-pushy

iOS 引入jar包创建APNSConnect进行发送报错对照表 引入jar包 创建APNSConnect 创建APNSConnect,与APNs进行链接 public class APNSConnect { private static ApnsClient apnsClient null;public static ApnsClient getAPNSConnectP8(String path,String teamId,S…

unipush+java+个推实现app消息推送

“ 你现在的气质里,藏着你走过的路,读过的书,和爱过的人。 ” 整体还是比较简单地,就是有一些需要注意的地方,很多问题官方文档里面也写了,这里总结一下 对于安卓,谷歌本来有专门的推送通道&…

uniapp - App 超详细消息推送功能实现,从 0-1 实现官方 unipush 推送全步骤稳定性毋庸置疑(附带详细的可运行示例源码和注释,保证 100% 完美接入)苹果安卓手机

效果图 网上的教程太乱用不了,无法改造成自己想要的效果。 在uniapp中开发的app(安卓苹果),使用 unipush 官方推送,从0-1实现完整过程及功能开发。 你可以直接复制示例源码,跟着教程一步步配置,注释详细! 准备 消息

Android 项目必备(三十八)-->APP 消息推送

文章目录 前言推送的实现方式1. C2DM2. 轮询3. SMS信令推送4. MQTT协议5. XMPP协议6. 使用第三方平台 Android 中 MQTT 的使用1. 集成2. 具体代码3. 项目地址 前言 今天来讲讲推送这件小事,事虽小,要做好却不容易。 推送难,难于上青天。 我们…

APP消息推送(APP Push)解决方案-服务端工作逻辑和实现

一、APP 推送概述: App推送消息是我们常见的一种app消息提醒方式。 我们的实现需要第三方的支持,实现方式是后台通过接口将Push请求发送至第三方,第三方实现在App所在设备上的推送。 二、APP推送后台处理逻辑: 在与推送平台交互时…

app消息推送的详细实现教程

实现的主要思想 app实现消息推送,利用的是第三方的个推平台,后端将需要推送的内容通过第三方个推服务器传递给手机端。 具体前端打包配置 根据上图可知,采用的打包软件是Hbuilder X,在模块配置的时候,勾选push模块中的uniPush。…

App消息推送的原理

文章目录 1. 基本概念2. iOS和Android消息推送原理对比2.1 iOS2.1.1 基本原理2.1.2 优劣势 2.2 Android2.2.1 基本原理2.2.2 优劣势 3. Android消息推送原理3.1 操作系统有自身的消息推送功能(系统级别)3.2 三种基本的推送方式:Push、Pull 和…

php实现app消息推送

如何用php实现APP消息推送 现在有很多的消息推送厂商,比如阿里云的消息推送,极光推送,融云的消息推送。他们的原理都是把sdk内置在app里面,达到消息推送的目的,通过一张图来了解一下,看不懂不要紧&#xf…

Android,ios,安卓app推送消息通知,java后台向手机推送app的通知教程

文章目录 一、业务介绍1.1 产品简介1.2 名词解释1.3 消息推送流程 二、应用创建三、客户端 SDK 集成3.1 Android3.2 iOS 四、服务端推送4.1 服务端消息下发流程(必读)4.2 开发者中心后台4.3 推送代码 五、参数说明 一、业务介绍 1.1 产品简介 个推是商…

App消息推送概述

消息推送介绍 消息推送(Push),是指从云端服务器到手机终端的消息推送通道,运营人员可以通过自己产品后台或者第三方推送通道对用户移动设备进行主动的消息推送。通过消息推送,目标用户可以在移动设备通知和状态栏看到…

PushDeer:一种无APP的通知推送解决方案

概述 去年六月,我曾写下一篇博客介绍如何 借助 ServerChan 实现个人微信通知推送,在那篇文章中介绍了 ServerChan 及其使用方法,总的来说,对于简单的通知需求,使用 ServerChan 是非常简单有效的。但是实际使用起来&…

一文让你知道关于App推送那些事

推送相关介绍 在用户未打开App时,服务端向用户推送服务器最新的消息数据,称为推送。消息推送在移动开发中用到的场景非常多,比如典电商类app的商品促销活动,资讯类的app的新闻推送等等。在实际开发中,我们常常会根据产…

关于ISO27701隐私信息安全管理体系介绍

01 什么是ISO27701 ISO27701是对ISO27001信息安全管理和ISO27002安全控制的隐私扩展,全称《安全技术—扩展ISO27001和ISO27002的隐私信息管理—要求与指南》,是ISO标准委员会以ISO 27001为基准,以ISO27552为蓝本,建立发布的隐私…

双向循环链表、dancing links

目录 双向循环链表 力扣 426. 将二叉搜索树转化为排序的双向链表 十字交叉双向循环链表(dancing links) 精确覆盖问题 dancing links X算法(V1递归版) POJ 3740 Easy Finding 数独 X算法优化 X算法(V2非递归…

jpg照片太大怎么压缩变小?jpg如何缩小图片大小kb?

我们平时在接收过多的jpg格式图片的时候,越大的图片虽然越清晰,但是接收和储存起来就非常不方便,那么有没有什么办法可以将jpg图片压缩呢?其实现在可以通过在线图片处理工具来完成jpg压缩(https://www.yasuotu.com/jpg…

html宽度一变小图形就上去,如何把图片大小变小?

我们在布局图片列表时,通常我们要控制图片的高度和宽度这样来达到图片统一。我们在HTML布局时候直接在图片img标签加宽度和高度属性即可控制图片高和宽。 一、html img图片标签高度宽度设置 我们可以直接在图片标签设置宽度width和高度height,这里需要注…