数据增强,迁移学习,Resnet分类实战

目录

1. 数据增强(Data Augmentation)

2. 迁移学习

3. 模型保存    

4. 102种类花分类实战

1. 数据集

2.导入包

3. 数据读取与预处理操作 

4. Datasets制作输入数据

5.将标签的名字读出 

6.展示原始数据 

7.加载models中提供的模型 

8.初始化 

9.优化器设置 

10.训练模块


1. 数据增强(Data Augmentation)

        数据不够怎么办?采用翻转,镜像,增加数据

        如何更加高效利用数据?多利用几次

        在pytorch中有数据预处理部分:

            数据增强:torchvision中transforms模块自带功能,比较实用

            数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可

            DataLoader模块直接读取batch数据

        pyorch官网:https://pytorch.org/vision/stable

2. 迁移学习

        在训练自己的模型时出现一些问题:

        1. 自己的数据不够好

        2. 训练参数花费时间多

        3. 训练模型太难

        解决方法:

        有前人已经训练好了模型,其实就是将训练的参数保留下来,而且目标都差不多。那么把别人的模型参数当成初始化参数,所有的结构和前人模型一样。

        网络模块设置:

    加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且也可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习

    需要注意的是别人训练好的任务根咱们的可不是完全一样的,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务

    训练时可以完全重头训练,也可以只训练最后咱们任务层,因为前几层都是做特征提取的,本质任务目标一致的。

        总结:迁移学习策略

                1. 将卷积层当成初始化权重参数

                2.将卷积层权重参数冻住不变,全连接层重新训练(一般是,数据量少,冻住的层数多)

3. 模型保存    

        网络模型保存与测试

            模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存

            读取模型进行实际测试

4. 102种类花分类实战

      1. 数据集

        有训练集,测试集。一共102种花,每种花有25~100个图像

2.导入包

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

3. 数据读取与预处理操作 

data_dir = './flower_data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

制作好数据源:

    data_transforms中指定了所有图像预处理操作

    ImageFolder假设所有文件按文件夹保存好,每个文件夹下面存储同一类别的图片,文件夹的名字为分类的名字

data_transforms = {'train' : transforms.Compose([transforms.RandomRotation(45),#随即旋转,-45度到45度之间随机选transforms.CenterCrop(224), #从中心点开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随即水平翻转,选择一个概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1), #参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),#转换成tensor格式transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#均值,标准差]), 'valid':transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])                   ])
}

 4. Datasets制作输入数据

        采用batch,将数据分组输入。

batch_size  = 8image_datasets = {x : datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x]) for x in ['train','valid']}
dataloaders = {x : torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=True) for x in ['train','valid']}
dataset_szies = {x :len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes
print(image_datasets)
print(dataloaders)

5.将标签的名字读出 

        用123....打标签好像不好,用花的名字作为标签

#读取标签对应的实际名字
with open('./flower_data/cat_to_name.json','r') as f:cat_to_name = json.load(f)
print(cat_to_name)

6.展示原始数据 

        展示下数据

            注意tensor的数据需要转换成numpy格式,而且还需要还原成标准化的结果

def im_convert(tensor):'''展示数据'''image = tensor.to('cpu').clone().detach()image = image.numpy().squeeze()image = image.transpose(1,2,0)image = image * np.array((0.229,0.224,0.225)) + np.array((0.485,0.456,0.406))image = image.clip(0,1)return imagefig = plt.figure(figsize=(20,12))
colunms = 4
rows = 2dataiter = iter(dataloaders['valid'])
inputs,classes =next(dataiter)for idx in range(colunms * rows):ax = fig.add_subplot(rows,colunms,idx+1,xticks = [],yticks = [])ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])plt.imshow(im_convert(inputs[idx]))
plt.show()

7.加载models中提供的模型 

        加载models中提供的模型,并且直接用训练好的权重当作初始化参数

            第一次执行需要下载,可能会比较慢

model_name = 'resnet' #可选的会比较多['resnet','alexnet','vgg','squeezenet','densenet','inception']
# 是否用人家训练好的特征来做
feature_extract = True#是否用GPU训练
train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available .  Training on CPU...')
else:print('CUDA is  available .  Training on GPU...')device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')def set_parameter_requires_grad(model,feature_extracting):if feature_extract:for param in model.parameters():param.requires_grad = False #冻不冻住model_ft = models.resnet152()
print(model_ft)

8.初始化 

        迁移学习,用前人的参数,改变全连接层。

def initalize_model(model_name,num_classes,feature_extract,use_pretrained=True):#选择合适的模型,不同模型的初始化方法稍微有点区别model_ft = Noneinput_size = 0if model_name == 'resnet':model_ft = models.resnet152(pretrained=use_pretrained)set_parameter_requires_grad(model_ft,feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Sequential(nn.Linear(num_ftrs,num_classes),nn.LogSoftmax(dim=1))input_size = 224return model_ft,input_sizefeature_extract = True
model_ft,input_size = initalize_model(model_name,102,feature_extract,use_pretrained=True)#GPU计算
model_ft = model_ft.to(device)#模型保存
filename = 'checkpoint.pth'#是否训练所有层params_to_updata = model_ft.parameters()
print('Params to learn')
if feature_extract:params_to_updata = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_updata.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)print(model_ft)

 

9.优化器设置 

#优化器设置
optimizer_ft = optim.Adam(params_to_updata,lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) #学习率每7个epoch衰减成原来的1/10
#最后一层已经LogSoftmax()了,所以不能nn.CrossEntropyLoss()来计算了,nn.CrossEntropyLoss()相当于logSoftmax()和nn.NLLoss()整合
criterion = nn.NLLLoss()

10.训练模块

def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename=filename):since = time.time()best_acc = 0'''checkpoint = torch.laod(filename)best_acc = checkpoint['best_acc]model.load_state_dict(checkpoint['optimizer'])model.class_to_idx = checkpoint['mapping']'''model.to(device)val_acc_history = []train_acc_history = []train_losses = []vaild_losses = []LRs = [optimizer.param_groups[0]['lr']]best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch,num_epochs-1))print('-'*10)#训练和验证for phase in ['train','valid']:if phase == 'train':model.train() #训练else:model.eval()  #验证running_loss = 0.0running_corrects = 0#把数据都取个遍for inputs,labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)#清零optimizer.zero_grad()#只有训练的时候计算和更新梯度with torch.set_grad_enabled(phase=='train'):if is_inception and phase == 'train':outputs,aux_outputs = model(inputs)loss1 = criterion(outputs,labels)loss2 = criterion(aux_outputs,labels)loss = loss1 + loss2else: #resnet执行的是这里outputs= model(inputs)loss = criterion(outputs,labels)_,preds = torch.max(outputs,1)#训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()#计算损失running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - sinceprint('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase,epoch_loss,epoch_acc))#得到最好的那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state ={'state_dict' : model.state_dict(),'best_acc': best_acc,'optimizer': optimizer.state_dict(),}torch.save(state,filename)if phase == 'valid':val_acc_history.append(epoch_acc)vaild_losses.append(epoch_loss)scheduler.step(epoch_loss)if phase=='train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60,time_elapsed % 60))print('Best val Acc: {:.4f}'.format(best_acc))#训练完后用最好的一次当作模型的最终结果model.load_state_dict(best_model_wts)return model, val_acc_history,train_acc_history,vaild_losses,train_losses,LRs
#开始训练!!!
model_ft,val_acc_history,train_acc_history,vaild_losses,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3031983.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

从静态PPT到智能演讲——人工智能在演示文稿中的应用

1.概述 在这个信息过载的时代,能够吸引并持续吸引观众的注意力无疑成为了一项艰巨的任务。公众演讲领域正经历着一场由人工智能(AI)引领的革命。AI不仅在制作引人入胜的内容方面发挥作用,而且在分析演讲的传递方式上也起着关键作…

【C++】 类的6个默认成员函数

目录 1. 类的6个默认成员函数 一.构造函数 1.基本概念 2 特性 注意:C11 中针对内置类型成员不初始化的缺陷,又打了补丁, 3.构造函数详解 3.1构造函数体赋值 3.2 初始化列表 3.3 explicit关键字 二.析构函数 1 概念 2 特性 两个栈实…

Vue路由拆分

1.在src下建立router&#xff0c;在router中建立文件index 2.将main.js中部分内容复制 App <template> <div><a href"#/friend">朋友</a><br><a href"#/info">信息</a><br><a href"#/music&quo…

Photoshop中图层的应用

Photoshop中图层的应用 前言Photoshop中的图层面板Photoshop中图层的基本操作新建图层复制/剪切图层链接图层修改图层名称及颜色背景图层与普通图层栅格化图层图层的对齐与分布图层的合并 前言 图层在Photoshop中就像一层一层的透明纸&#xff0c;可以透过图层的透明区域看到下…

动手学深度学习16 Pytorch神经网络基础

动手学深度学习16 Pytorch神经网络基础 1. 模型构造2. 参数管理1. state_dict()2. normal_() zeros_()3. xavier初始化共享参数的好处 3. 自定义层4. 读写文件net.eval() 评估模式 QA 1. 模型构造 定义隐藏层–模型结构定义前向函数–模型结构的调用 import torch from torch…

万村乐数字乡村综合服务系统如何助力农民收入的腾飞

作为行业领先的数字乡村综合服务系统——“万村乐”&#xff0c;其核心便是基于互联网乡村和物联网乡村的强大信息基石之上。通过幸福民生服务、高效政务服务以及规范的党务服务这三条主线&#xff0c;以手机端平台为承载&#xff0c;借助事件反馈、精准种养数据、精细人员网格…

【Java】/*方法的使用-快速总结*/

目录 一、什么是方法 二、方法的定义 三、实参和形参的关系 四、方法重载 五、方法签名 一、什么是方法 Java中的方法可以理解为C语言中的函数&#xff0c;只是换了个名称而已。 二、方法的定义 1. 语法格式&#xff1a; public static 返回类型 方法名 (形参列表) { //方…

AI领域最伟大的论文检索网站

&#x1f4d1; 苏剑林&#xff08;Jianlin Su&#xff09;开发的“Cool Papers”网站旨在通过沉浸式体验提升科研工作者浏览论文的效率和乐趣。这个平台的核心优势在于利用Kimi的智能回答功能&#xff0c;帮助用户快速了解论文的常见问题&#xff08;FAQ&#xff09;&#xff0…

定了,2024年天门中级职称报名开始了

关于今年天门中级职称报名各类相关事宜&#xff0c;我们一起来看看 一、报名时间和地址 1.报名时间&#xff1a;2024年5月10日至5月22日&#xff0c;并由主管部门或用人单位将报名表提交给人力资源部&#xff08;注意不要错过时间了&#xff09; 水测准考证领取时间为正式考试…

卷积模型的剪枝、蒸馏---蒸馏篇--NST特征蒸馏(以deeplabv3+为例)

本文使用NST特征蒸馏实现deeplabv3+模型对剪枝后模型的蒸馏过程; 一、NST特征蒸馏简介 下面是两张叠加了热力图(heat map)的图片,从图中很容易看出这两个神经元具有很强的选择性:左图的神经元对猴子的脸部非常敏感,右侧的神经元对字符非常敏感。这种激活实际上意味着神经…

自回归模型的优缺点及改进方向

在学术界和人工智能产业中&#xff0c;关于自回归模型的演进与应用一直是一个引发深入讨论和多方观点交锋的热门议题。尤其是Yann LeCun&#xff0c;这位享誉全球的AI领域学者、图灵奖的获得者&#xff0c;以及被誉为人工智能领域的三大巨擘之一&#xff0c;他对于自回归模型持…

笔记2:torch搭建VGG网络代码详细解释

VGG网络结构 VGG网络&#xff08;Visual Geometry Group Network&#xff09;是一种经典的深度学习卷积神经网络&#xff08;CNN&#xff09;架构&#xff0c;由牛津大学的视觉几何组&#xff08;Visual Geometry Group&#xff09;在2014年提出。VGG网络在ImageNet挑战赛2014…

软件开发项目实施方案-精华资料(Word原件)

依据项目建设要求&#xff0c;对平台进行整体规划设计更新维护&#xff0c;对系统运行的安全性、可靠性、易用性以及稳健性进行全新设计&#xff0c;并将所有的应用系统进行部署实施和软件使用培训以及技术支持。 根据施工总进度规划&#xff0c;编制本项目施工进度计划表。依据…

OSPF虚链路

原理概述 通常情况下&#xff0c;一个OSPF网络的每个非骨干区域都必须与骨干区域通过ABR路由器直接连接&#xff0c;非骨干区域之间的通信都需要通过骨干区域进行中转。但在现实中&#xff0c;可能会因为各种条件限制&#xff0c;导致非骨干区域和骨干区域无法直接连接&#x…

在家就可以轻松赚零花钱的副业

互联网的兴起让很多人实现了在家办公的梦想&#xff0c;同时也为人们提供了更多的挣钱方式。以下是4种可以在家中兼职副业赚钱的方法&#xff1a; 1. 写作工作 如果你善于写作&#xff0c;并且有一定的文学素养&#xff0c;那么可以通过自己的博客或其他媒体平台来写作&#…

SMART700西门子触摸屏维修6AV6 648-0CC11-3AX0

西门子工控机触摸屏维修系列型号&#xff1a;PС477,PC677,TD200,TD400,KTP178,TP170A,TP170B,TP177A,TP177B,TP270,TP277,TP27,MP370,MP277,OP27,OP177B等。 触摸屏故障有&#xff1a;上电黑屏, 花屏,暗屏,触摸失灵,按键损坏,电源板,高压板故障,液晶,主板坏等,内容错乱、进不了…

2024.5.6 关于 SpringCloud 的基本认知

目录 引言 微服务框架所包含的技术栈 面试题 微服务架构演变 单体架构 分布式架构 微服务架构 微服务技术对比 认识 SpringCloud SpringBoot 版本兼容关系 服务拆分和远程调用 服务拆分注意事项 远程调用 引入问题 引言 微服务是一种框架风格&#xff0c;按照业务…

解决离线服务器无法加载HuggingFaceEmbeddings向量化模型的问题

由于服务器是离线的&#xff0c;因此我先在本地到huggingface官网下载模型text2vec&#xff0c;然后上传到服务器上运行&#xff0c;报错&#xff1a; (MaxRetryError(HTTPSConnectionPool(host\huggingface.co\, port443): Max retries exceeded with url: /api/models/senten…

Windows 11 Manager (Win11系统优化大师) 中文破姐版 v1.4.3

01 软件介绍 ​Windows 11 Manager v1.4.3是一款综合性的系统优化工具&#xff0c;专为Win11设计。该工具包含超过40种功能&#xff0c;旨在全方位提升操作系统的性能。通过这些工具&#xff0c;用户可以对Windows 11进行深度优化和微调&#xff0c;清除不必要的文件&#xff…

Nginx 从入门到实践(3)——负载均衡、反向代理、动静分离

Nginx代理服务 Nginx代理服务 Nginx代理服务Nginx负载均衡反向代理反向代理的用途 Nginx配置攻略Nginx动静分离使用 Nginx 实现四层代理配置基本介绍使用 Nginx 实现四层代理配置 Nginx负载均衡 负载均衡&#xff08;Load Balance&#xff09;是由多台服务器以对称的方式组成一…