Pytorch 复习总结 3

Pytorch 复习总结,仅供笔者使用,参考教材:

  • 《动手学深度学习》
  • Stanford University: Practical Machine Learning

本文主要内容为:Pytorch 多层感知机。

本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟合现象提出解决办法。


Pytorch 语法汇总:

  • Pytorch 张量的常见运算、线性代数、高等数学、概率论 部分 见 Pytorch 复习总结1;
  • Pytorch 线性神经网络 部分 见 Pytorch 复习总结2;
  • Pytorch 多层感知机 部分 见 Pytorch 复习总结3;
  • Pytorch 深度学习计算 部分 见 Pytorch 复习总结4;
  • Pytorch 卷积神经网络 部分 见 Pytorch 复习总结5;
  • Pytorch 现代卷积神经网络 部分 见 Pytorch 复习总结6;

目录

  • 一. 多层感知机
    • 1. 读取数据集
    • 2. 神经网络模型
    • 3. 激活函数
    • 4. 损失函数
    • 5. 优化器
    • 6. 训练
  • 二. 过拟合的缓解
    • 1. 权重衰减
    • 2. Dropout

一. 多层感知机

虽然线性模型易于实现和理解、计算成本低、泛化能力强,但是对于一些非线性问题,可能会违反线性模型的单调性。为此,多层感知器引入了隐藏层来克服线性模型的限制,并且加入激活函数以增强网络非线性建模能力。

1. 读取数据集

同 Pytorch 复习总结 2 中 Softmax 回归的数据读取,继续使用 Fashion-MNIST 图像分类数据集:

import torch
import torchvision
from torch.utils import data
from torchvision import transformsdef load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集并将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True),data.DataLoader(mnist_test, batch_size, shuffle=False))batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

2. 神经网络模型

先将输入的图像展平,然后使用 2 个全连接层进行处理,中间的全连接层需要使用激活函数激活,最后一层全连接层作为输出:

from torch import nn
net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10)
)

仍然使用 init_weights() 函数按正态分布初始化所有全连接层的权重:

def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

3. 激活函数

上一节使用了 ReLU 函数进行激活,在实际应用中,还可以使用 sigmoid、tanh 等函数激活。ReLU、sigmoid、tanh 函数的梯度可视化如下:

import torch
from matplotlib import pyplot as pltx = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
# y = torch.relu(x)
# y = torch.sigmoid(x)
y = torch.tanh(x)
y.backward(torch.ones_like(x), retain_graph=True)
plt.figure(figsize=(5, 2.5))
plt.plot(x.detach(), x.grad)
plt.show()

4. 损失函数

同 Softmax 回归:

loss = nn.CrossEntropyLoss(reduction='none')

5. 优化器

同 Softmax 回归:

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

6. 训练

同 Softmax 回归,可以将训练过程封装成函数:

def accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def train_net(net, train_iter, test_iter, loss, num_epochs, trainer):for epoch in range(num_epochs):     # 迭代训练轮次net.train()                     # 将模型设置为训练模式train_loss_sum = 0.0            # 训练损失总和train_acc_sum = 0.0             # 训练准确度总和sample_num = 0                  # 样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y)trainer.zero_grad()l.mean().backward()trainer.step()train_loss_sum += l.sum()train_acc_sum += accuracy(y_hat, y)sample_num += y.numel()train_loss = train_loss_sum / sample_numtrain_acc = train_acc_sum / sample_numnet.eval()                      # 将模型设置为评估模式test_acc_sum = 0.0test_sample_num = 0for X, y in test_iter:test_acc_sum += accuracy(net(X), y)test_sample_num += y.numel()test_acc = test_acc_sum / test_sample_numprint(f'epoch {epoch + 1}, 'f'train loss {train_loss:.4f}, train acc {train_acc:.4f}, 'f'test acc {test_acc:.4f}')num_epochs = 10
train_net(net, train_iter, test_iter, loss, num_epochs, trainer)

二. 过拟合的缓解

当模型过于复杂、训练数据太少、迭代轮数太多时,就会出现过拟合现象。解决过拟合的方法有很多:

  • 增加数据量:增加训练数据可以帮助模型更好地学习数据的真实规律,减少过拟合的发生;
  • 简化模型:降低模型的复杂度,可以通过减少模型的参数数量、使用正则化等方法来实现;
  • 交叉验证:使用交叉验证来评估模型的泛化能力,选择最优的模型;
  • 提前停止:即 Dropout,在训练过程中监控模型在验证集上的表现,当验证集误差不再下降甚至开始上升时,及时停止训练,防止模型过拟合;
  • 集成学习:使用集成学习方法(如随机森林、梯度提升树等)降低模型的方差,提高泛化能力。

下面介绍几种常用的正则化方法。

1. 权重衰减

权重衰减 (Weight Decay) 通过向损失函数中添加一个惩罚项来减小模型复杂度,以防止过拟合。惩罚项也叫 正则项,通常是权重的平方和(即 L2 范数)或权重的绝对值和(即 L1 范数)乘以一个正则化系数。

以线性回归的损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 为例,使用优化器训练时,在损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 上添加 L2 范数如下:
L ( w , b ) + λ 2 ∥ w ∥ 2 = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b)+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ =\frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right)^2+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ L(w,b)+2λw2=n1i=1n21(wx(i)+by(i))2+2λw2

损失函数中没有添加偏置 b b b 的惩罚项,因为一般情况下,网络输出层的偏置项不需要正则化。代入 w \mathbf{w} w 的参数更新表达式为:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) \mathbf{w} \leftarrow(1-\eta \lambda) \mathbf{w}-\frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right) w(1ηλ)wBηiBx(i)(wx(i)+by(i))

要想对模型进行权重衰减,只需要在实例化优化器时通过 weight_decay 指定权重衰减参数。默认情况下,PyTorch 同时衰减权重和偏移:

trainer = torch.optim.SGD(net.parameters(), lr=lr)

如果想要只衰减权重,需要指定参数:

params_to_optimize = [{"params": net[0].weight, 'weight_decay': wd},{"params":net[0].bias}
]
trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)

2. Dropout

Dropout 通过在训练过程中随机地将网络 内部 的一部分神经元的输出设置为零,即以一定的概率 “丢弃” 这些神经元。这样可以防止神经元在训练过程中过于依赖其他神经元,从而降低了网络对特定神经元的依赖性,使得网络更具鲁棒性:
在这里插入图片描述

通常情况下,Dropout 只在训练过程中使用,不在推理阶段使用,因为推理时模型需要产生确定性的输出。

Dropout 需要在网络中添加 Dropout 层,一般位于激活函数后,并且给定 dropout 概率:

dropout1, dropout2 = 0.2, 0.5net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),nn.Dropout(dropout2),nn.Linear(256, 10)
)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

Dropout 概率的设置技巧是靠近输入层的地方设置较低的概率,远离输入层的地方设置较高的概率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2799586.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

win10编译openjdk源码

上篇文章作者在ubuntu系统上实践完成openjdk源码的编译,但是平常使用更多的是window系统,ubuntu上编译出来JDK无法再windows上使用。所以作者又花费了很长时间在windows系统上完成openjdk源码的编译,陆续花费一个月的时间终于完成了编译。 本…

【Unity3D】ASE制作天空盒

找到官方shader并分析 下载对应资源包找到\DefaultResourcesExtra\Skybox-Cubed.shader找到\CGIncludes\UnityCG.cginc观察变量, 观察tag, 观察代码 需要注意的内容 ASE要处理的内容 核心修改 添加一个Custom Expression节点 code内容为: return DecodeHDR(In0, In1);outp…

[bing]“gang调度 Kubernetes的并发控制和一致性机制“论点的对应的源码分析

你是一位K8S专家。请分析在Kubernates(https://github.com/kubernetes/kubernetes.git)项目和调度coscheduling(https://github.com/kubernetes-sigs/scheduler-plugins/tree/master/pkg/coscheduling) 插件中支撑"PodGroup的管理和调度决策涉及到对…

Elasticsearch:使用 ELSER v2 进行语义搜索

在我之前的文章 “Elasticsearch:使用 ELSER 进行语义搜索”,我们展示了如何使用 ELESR v1 来进行语义搜索。在使用 ELSER 之前,我们必须注意的是: 重要:虽然 ELSER V2 已正式发布,但 ELSER V1 仍处于 [预览…

MKS薄膜规622/626/627/628/629说明接口定义等说明

MKS薄膜规622/626/627/628/629说明接口定义等说明

【day02】每天三道 java后端面试题:Java、C++和Go的区别 | Redis的特点和应用场景 | 计算机网络七层模型

文章目录 1. Java、C和 Go 语言的区别,各自的优缺点?2. 什么是Redis?Redis 有哪些特点? Redis有哪些常见的应用场景?3. 简述计算机网络七层模型和各自的作用? 1. Java、C和 Go 语言的区别,各自的…

Bert基础(三)--位置编码

背景 还是以I am good(我很好)为例。 在RNN模型中,句子是逐字送入学习网络的。换言之,首先把I作为输入,接下来是am,以此类推。通过逐字地接受输入,学习网络就能完全理解整个句子。然而&#x…

【实战】CEF框架集成MFC DLL的一些坑

文章目录 DLL 导出方式,函数指针的坑问题出现问题解决 CEF集成MFCafxwin.h预编译编译参数修改 DLL重命名 MFC作为微软的长期主力开发套件之一,之前很多设备开发的C/S端界面都是通过MFC框架来做的,而在我自己的CEF项目中,会集成很多…

顺序表详解(如何实现顺序表)

文章目录 前言 在进入顺序表前,我们先要明白,数据结构的基本概念。 一、数据结构的基本概念 1.1什么是数据结构 数据结构是由“数据”和“结构”两词组合而来。所谓数据就是?常见的数值1、2、3、4.....、姓名、性别、年龄,等。…

React基础-webpack+creact-react-app创建项目

学习视频:学习视频 2节:webpack工程化创建项目 2.1.webpack工程化工具:vite/rollup/turbopak; 实现组件的合并、压缩、打包等; 代码编译、兼容、校验等; 2.2.React工程化/组件开发 我们可以基于webpack自己去搭建…

LeetCode | 寻找两个正序数组的中位数 Python C语言

Problem: 4. 寻找两个正序数组的中位数 文章目录 思路解题方法Code结果结果一些思考 思路 先合并,后排序,最后找中间轴。 解题方法 由解题思路可知 Code 这是python3的代码。 class Solution(object):def findMedianSortedArrays(self, nums1, num…

IDEA 2021.3激活

1、打开idea,在设置中查找Settings/Preferences… -> Plugins 内手动添加第三方插件仓库地址:https://plugins.zhile.io搜索:IDE Eval Reset 插件进行安装。应用和使用,如图

Spring 类型转换、数值绑定与验证(一)— DataBinder

DataBinder 是Spring用于数据绑定、类型转换及验证的类。使用场景有:1)xml配置文件定义bean,Spring 内部使用DataBinder 来完成属性的绑定;2)Web请求参数绑定,在Spring MVC 中,Controller的方法参数通常会自…

基于机器学习的青藏高原高寒沼泽湿地蒸散发插补研究_王秀英_2022

基于机器学习的青藏高原高寒沼泽湿地蒸散发插补研究_王秀英_2022 摘要关键词 1 材料和方法1.1 研究区概况与数据来源1.2 研究方法 2 结果和分析2.1 蒸散发通量观测数据缺省状况2.2 蒸散发与气象因子的相关性分析2.3 不同气象因子输入组合下各模型算法精度对比2.4 随机森林回归模…

《图解设计模式》笔记(一)适应设计模式

图灵社区 - 图解设计模式 - 随书下载 评论区 雨帆 2017-01-11 16:14:04 对于设计模式,我个人认为,其实代码和设计原则才是最好的老师。理解了 SOLID,如何 SOLID,自然而然地就用起来设计模式了。Github 上有一个 tdd-training&…

第3.4章:StarRocks数据导入-Routine Load

注:本篇文章阐述的是StarRocks-3.2版本的Routine Load导入机制 一、概述 Routine Load(例行导入)支持用户提交一个常驻的导入任务,可以将消息流存储在 Kafka 的Topic中,通过订阅Topic 中的全部或部分分区的消息&#…

多个.C 文件关于全局变量如何使用

𝙉𝙞𝙘𝙚!!👏🏻‧✧̣̥̇‧✦👏🏻‧✧̣̥̇‧✦ 👏🏻‧✧̣̥̇:Solitary_walk ⸝⋆ ━━━┓ - 个性标签 - :来于“云”的“羽球人”。…

C语言新手写函数中出现数组时运行bug的解决

一.发现问题&#xff1a; 这是我今天写代码的一小部分&#xff0c;是创建一个数组&#xff0c;然后函数init&#xff08;&#xff09;是初始化数组&#xff0c;代码如下&#xff1a; void init(int arr[10],unsigned int k) {int i 0;for (i 0; i < k; i) {arr[i] 0;} …

深度学习系列60: 大模型文本理解和生成概述

参考网络课程&#xff1a;https://www.bilibili.com/video/BV1UG411p7zv/?p98&spm_id_frompageDriver&vd_source3eeaf9c562508b013fa950114d4b0990 1. 概述 包含理解和分类两大类问题&#xff0c;对应的就是BERT和GPT两大类模型&#xff1b;而交叉领域则对应T5 2.…

机器学习基本概念(李宏毅课程)

目录 一、概念:1、机器学习概念:2、深度学习概念&#xff1a; 二、深度学习中f(.)的输入和输出&#xff1a;1、输入&#xff1a;2、输出&#xff1a; 三、三种机器学习任务&#xff1a;1、Regression回归任务介绍&#xff1a;2、Classification分类任务介绍&#xff1a;3、Stru…