Neural Network学习笔记3

损失函数和反向传播网络

在进行损失函数计算后,再进行.backward()反向传播。

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("dataset_transform",train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=1)class Zrf(nn.Module):def __init__(self):super(Zrf, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x# 分类问题适合用交叉熵
loss = nn.CrossEntropyLoss()
zrf = Zrf()
for data in dataloader:imgs, targets = dataoutputs = zrf(imgs)# print(outputs)# print(targets)result_loss = loss(outputs, targets)result_loss.backward()

优化器

  以Adadelta为例,torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)

params: 模型的参数,让优化器知道我们的模型长什么样子。

lr: Learning rate, 学习率

其他的参数可以采用默认,并且优化算法不同,参数也会有很大不同。

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("dataset_transform",train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=1)class Zrf(nn.Module):def __init__(self):super(Zrf, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x# 分类问题适合用交叉熵
loss = nn.CrossEntropyLoss()
zrf = Zrf()
# 设置优化器
# SGD 随机梯度下降
# lr的设不可以太大也不可以太小,一般情况下我们采用训练开始时lr大,之后的训练中lr小的方式
optim = torch.optim.SGD(zrf.parameters(), lr=0.01)
for epoch in range(20) :running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = zrf(imgs)result_loss = loss(outputs, targets)# 在进行反向传播来计算梯度时,要先将梯度置为0,防止之前计算出来的梯度的影响optim.zero_grad()result_loss.backward()# 根据梯度对卷积核参数进行调优optim.step()running_loss = running_loss + result_lossprint(running_loss)

现有网络模型的使用及修改

torchvision.modles中的VGG为例。VGG常用VGG16和VGG19。

weights: 可选,要使用的预训练权重。默认情况下,不使用预先训练的权重。

progress: true时,会展示一个进度条

此外,pytorch在下载模型时会把模型下载到C盘,下面语句可以修改下载位置:

os.environ['TORCH_HOME'] = '/path/to/torch_home'

import torchvision
from torch import nn
from torchvision.models import VGG16_Weights
import os# train_data = torchvision.datasets.ImageNet(root="data_image_net", split="train", download=True,
#                                            transform=torchvision.transforms.ToTensor())# 最新版本默认是没有预训练的,需要使用预训练设置weights='DEFAULT'os.environ['TORCH_HOME'] = '/path/to/torch_home'vgg16_noPre = torchvision.models.vgg16()
vgg16_pre = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)
print(vgg16_pre)# 微调网络模型
train_data = torchvision.datasets.CIFAR10("dataset_transform",train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 添加一层
# 在vgg整体层面上加
vgg16_pre.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_pre)
# 只在某一部分加(classifier部分)
vgg16_pre.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_pre)# 修改
vgg16_noPre.classifier[6] = nn.Linear(4096, 10)
print(vgg16_noPre)

网络模型的保存与读取

保存方法演示:model_save.py

import torch
import torchvision
from torch import nn# 使用未经过训练的,初始化的参数
vgg16 = torchvision.models.vgg16()
# 保存方式1
# 这样不仅保存了网络模型的结构,也保存了网络模型的参数
torch.save(vgg16, "vgg16_method1.pth")# 保存方式2
# 不保存网络的结构,只是把网络的参数保存成数据字典,也就是保存了网络的状态
# (官方推荐!!!)占用空间更小
torch.save(vgg16.state_dict(), "vgg16_method2.pth")# 陷阱1
class Zrf(nn.Module):def __init__(self):super(Zrf, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return xzrf = Zrf()
torch.save(zrf, "zrf_method1.pth")

读取方法演示:model_load.py

import torch
import torchvision
from torch import nn# 方式1 ----> 对应保存方式1,来加载模型
model = torch.load("vgg16_method1.pth")
# print(model)# 方式2 ---> 对应保存方式2
vgg16 = torchvision.models.vgg16() # 新建网络模型结构
model = torch.load("vgg16_method2.pth")
# print(model)
vgg16.load_state_dict(model) # 加载网络模型的状态
print(vgg16)# 陷阱1
# 在使用自己创建的网络时,注意要有网络的这个类在本文件(程序可以访问到网络),只是不需要zrf = Zrf()再创建网络了
# 方法1
# class Zrf(nn.Module):
#     def __init__(self):
#         super(Zrf, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         return x
# 方法2
from model_save import *model = torch.load("zrf_method1.pth")
print(model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/351071.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

迅雷远程与服务器失去响应,#原创新人# 彻底解决迅雷关闭接口对群晖NAS的影响...

#原创新人# 彻底解决迅雷关闭接口对群晖NAS的影响 2017-07-29 12:00:07 92点赞 668收藏 186评论 从7月17号开始,陆陆续续有小伙伴在各种论坛开始抱怨迅雷的远程下载被封停了,而时不时抽风说CAPTCHA验证码出问题的离线下载的体验也是相当不好,对相当一部分的NAS用户来说影响还…

Python3迅雷vip账号批量抓取导入excel中

1.抓取思路 打开迅雷粉列表页,找到最新账号 为了保证时效,选择第一个列表页进行抓取 利用正则表达式将账号密码提取 账号:([A-Za-z0-9]{6,}) 密码:([A-Za-z0-9]{6,}) 将筛选出的数据利用openpyxl导入到excel中 本次教程结束&am…

写个小爬虫爬下迅雷会员

好久不写爬虫。。。忘了个锤子。于是借着学java的机会用java写个爬虫 爬取静态页面 迅雷会员账号和密码。时时获取最新的~ 先上我暑假写的python版~ : # -*- coding: utf-8 -*- import urllib import re import os url1 http://xlfans.com/ regex r迅雷会员(.…

LNMP服务

目录 一、安装Nginx服务 1.编译安装nginx服务 2.添加nginx系统服务 二、安装Mysql服务 1.编译安装mysql服务 2.修改mysql配置文件 3.设置路径环境变量 4.初始化数据库 5.添加mysql系统服务 6.修改mysql 的登录密码 三、安装配置 PHP 解析环境 1.安装环境依赖包 2.编…

深圳市有什么靠谱的PMP机构推荐吗?

PMP项目管理专业人士资格认证是由美国项目管理协会(Project Management Institute,简称PMI)发起的。PMP作为世界级的项目管理认证证书,拥有着最先进的项目管理知识体系,它严格评估项目管理人员知识技能是否具有高品质的…

HarmonyOS3 Stage模型介绍

Stage模型是HarmonyOS 3.1 Develper Preview(API 9)版本开始新增的模型,也是目前HarmonyOS主推且会长期演进的模型。在该模型中,由于提供了AbilityStage、WindowStage等类作为应用组件和Window窗口的“舞台”,因此称这…

构建智能电商推荐系统:大数据实战中的Kudu、Flink和Mahout应用【上进小菜猪大数据】

上进小菜猪,沈工大软件工程专业,爱好敲代码,持续输出干货。 本文将介绍如何利用Kudu、Flink和Mahout这三种技术构建一个强大的大数据分析平台。我们将详细讨论这些技术的特点和优势,并提供代码示例,帮助读者了解如何在…

投影仪哪个牌子好?怎么选家用投影仪

这两年看我身边好多朋友都买了投影仪,我心里也是痒痒的。他们都说有了投影仪之后再也不用去电影院了,周末在家拉上窗帘,准备一堆甜品奶茶,躺在沙发上就可以开始享受家庭影院了。不过我在想,投影仪的牌子这么多&#xf…

十大国产投影仪品牌:极米、当贝、明基、小米等国产投影仪大牌厂商

投影仪,想必网友都比较熟悉了,目前很多家庭里都购入了投影仪作为观影设备;特别是在近些年,笔者不少朋友也选择投影仪给孩子们使用。经过多年的技术开发与经验累积,国产投影仪已成为高销量、高品质的品牌。 十大国产投…

家用投影仪什么品牌好?投影仪哪家好?

最近好多朋友都在问我家用投影仪什么品牌好。但是我的观点是选任何一个产品都不能只看品牌,一定先要了解如何选择该类产品,才能选出兼顾产品和品牌两个方面的好东西。因此,这篇文章将会先告诉大家如何选择投影仪,然后再给大家介绍…

投影仪参数哪些最重要?什么品牌投影仪好

大家都知道投影仪规格参数多,包括系统配置、显示参数、音效在内,大大小小总共几十个,要是每个都摘出来详细对比的话,也太花功夫了。俗话说得好,打蛇打七寸。与其面面俱到,不如分清主次,抓大放小…

怎么挑选投影仪?高清投影仪什么品牌好

随着家庭智能影院的兴起,投影仪逐渐成为家庭观影的一种新潮流。那么投影仪应该怎么挑选呢? 我们在选择投影仪的时候要注重不同功能参数之间的对比,下面将我自己选择投影仪的一些经验分享给大家。 首先看分辨率,我们都知道分辨率是…

投影仪什么牌子最好?哪款投影仪做家庭影院效果好

这几年来国内新兴的投影仪牌子不计其数,除了几个占据行业领先地位的老牌子之外,很多新起之秀也蓄势待发,在打造极致性价比方面卯足了劲儿,跟大牌竞争。 与其问现在什么牌子的投影仪最好,还不如自己学会看投影仪的参数&…

投影仪哪些比较好?投影仪如何选购

现在在家里装投影仪能提升幸福感,很多小伙伴准备入坑。但看到市场上那么多品牌和款型,不知道投影仪哪些比较好。接下来和大家分享自己的选购经验,后半部分整理出来了口碑比较好的部分产品,希望能帮助大家缩小选择的范围。 挑选指南…

投影仪家里用什么牌子好?哪种投影仪性价比高

人们对生活品质的追求,已经体现在投影仪上。不管是买房还是租房,都可以在家享受大屏电影的体验。可看着这么多牌子,很多人可能不知道怎么选。其实只要会看参数,就能知道投影仪家里用什么牌子好了。 1、显示芯片、分辨率 家里用的…

什么牌子投影仪好?投影仪买什么牌子的好

最近几年投影仪行业发展很快,除了几个传统的品牌,几个新兴的品牌也很受关注。概括起来国内有极米、坚果、大眼橙、明基等,国外有索尼、松下、爱普生。备选一多就容易纠结,很多人问什么牌子投影仪好,下面就分享一下自己…

什么牌子投影仪好?国产投影仪什么牌子好

小巧,智能,易于操作的物品越来越受到人们的喜爱。在科技的进步中,一些影视爱好者也不满足于电影院或者电视等传统观影方式,这也是投影仪越来越受人们欢迎的原因。它兼备了智能化与信息化等多种现代元素,同时能满足观影…

国产家用投影仪十大排名品牌,最新排名整理分享给大家选前要看哦

支持国货现在已然成为国人绝对支持的行为之一,对于像华为、鸿星尔克等国产国货出现火爆的场景,仍然历历在目!现在国产国货已经影响着世界,国际友人都爱上了中国造!今天小编分享新国货十大国产投影仪品牌排行榜&#xf…

投影仪哪个牌子的好?家庭影院投影仪哪款好

近年来的投影仪市场真的太火爆了,各大平台上都在推各种品牌的投影仪,有的是几百块钱价位的,有的是大几千的,还有上万的。作为一名家电行业的技术人员,个人觉得几百块钱的投影仪真心不能买,连智能系统都没有…

mac电脑git clone项目时报错证书过期和权限被拒绝

mac电脑使用git clone命令克隆项目时,一开始一直提示证书过期 SSL certificate problem: certificate has expired 执行以下代码关掉验证后,解决了这个问题 找到git目录 Git\git-cmd输入命令跳转到bin目录,cd bin输入命令运行git.exe执行关…