【深度学习笔记】3_12 权重衰减

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

3.12 权重衰减

上一节中我们观察了过拟合现象,即模型的训练误差远小于它在测试集上的误差。虽然增大训练数据集可能会减轻过拟合,但是获取额外的训练数据往往代价高昂。本节介绍应对过拟合问题的常用方法:权重衰减(weight decay)。

3.12.1 方法

权重衰减等价于 L 2 L_2 L2 范数正则化(regularization)。正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。我们先描述 L 2 L_2 L2范数正则化,再解释它为何又称权重衰减。

L 2 L_2 L2范数正则化在模型原损失函数基础上添加 L 2 L_2 L2范数惩罚项,从而得到训练所需要最小化的函数。 L 2 L_2 L2范数惩罚项指的是模型权重参数每个元素的平方和与一个正的常数的乘积。以3.1节(线性回归)中的线性回归损失函数

ℓ ( w 1 , w 2 , b ) = 1 n ∑ i = 1 n 1 2 ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) 2 \ell(w_1, w_2, b) = \frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(x_1^{(i)} w_1 + x_2^{(i)} w_2 + b - y^{(i)}\right)^2 (w1,w2,b)=n1i=1n21(x1(i)w1+x2(i)w2+by(i))2

为例,其中 w 1 , w 2 w_1, w_2 w1,w2是权重参数, b b b是偏差参数,样本 i i i的输入为 x 1 ( i ) , x 2 ( i ) x_1^{(i)}, x_2^{(i)} x1(i),x2(i),标签为 y ( i ) y^{(i)} y(i),样本数为 n n n。将权重参数用向量 w = [ w 1 , w 2 ] \boldsymbol{w} = [w_1, w_2] w=[w1,w2]表示,带有 L 2 L_2 L2范数惩罚项的新损失函数为

ℓ ( w 1 , w 2 , b ) + λ 2 n ∥ w ∥ 2 , \ell(w_1, w_2, b) + \frac{\lambda}{2n} \|\boldsymbol{w}\|^2, (w1,w2,b)+2nλw2,

其中超参数 λ > 0 \lambda > 0 λ>0。当权重参数均为0时,惩罚项最小。当 λ \lambda λ较大时,惩罚项在损失函数中的比重较大,这通常会使学到的权重参数的元素较接近0。当 λ \lambda λ设为0时,惩罚项完全不起作用。上式中 L 2 L_2 L2范数平方 ∥ w ∥ 2 \|\boldsymbol{w}\|^2 w2展开后得到 w 1 2 + w 2 2 w_1^2 + w_2^2 w12+w22。有了 L 2 L_2 L2范数惩罚项后,在小批量随机梯度下降中,我们将线性回归一节中权重 w 1 w_1 w1 w 2 w_2 w2的迭代方式更改为

w 1 ← ( 1 − η λ ∣ B ∣ ) w 1 − η ∣ B ∣ ∑ i ∈ B x 1 ( i ) ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) , w 2 ← ( 1 − η λ ∣ B ∣ ) w 2 − η ∣ B ∣ ∑ i ∈ B x 2 ( i ) ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) . \begin{aligned} w_1 &\leftarrow \left(1- \frac{\eta\lambda}{|\mathcal{B}|} \right)w_1 - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}}x_1^{(i)} \left(x_1^{(i)} w_1 + x_2^{(i)} w_2 + b - y^{(i)}\right),\\ w_2 &\leftarrow \left(1- \frac{\eta\lambda}{|\mathcal{B}|} \right)w_2 - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}}x_2^{(i)} \left(x_1^{(i)} w_1 + x_2^{(i)} w_2 + b - y^{(i)}\right). \end{aligned} w1w2(1Bηλ)w1BηiBx1(i)(x1(i)w1+x2(i)w2+by(i)),(1Bηλ)w2BηiBx2(i)(x1(i)w1+x2(i)w2+by(i)).

可见, L 2 L_2 L2范数正则化令权重 w 1 w_1 w1 w 2 w_2 w2先自乘小于1的数,再减去不含惩罚项的梯度。因此, L 2 L_2 L2范数正则化又叫权重衰减。权重衰减通过惩罚绝对值较大的模型参数为需要学习的模型增加了限制,这可能对过拟合有效。实际场景中,我们有时也在惩罚项中添加偏差元素的平方和。

3.12.2 高维线性回归实验

下面,我们以高维线性回归为例来引入一个过拟合问题,并使用权重衰减来应对过拟合。设数据样本特征的维度为 p p p。对于训练数据集和测试数据集中特征为 x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,,xp的任一样本,我们使用如下的线性函数来生成该样本的标签:

y = 0.05 + ∑ i = 1 p 0.01 x i + ϵ y = 0.05 + \sum_{i = 1}^p 0.01x_i + \epsilon y=0.05+i=1p0.01xi+ϵ

其中噪声项 ϵ \epsilon ϵ服从均值为0、标准差为0.01的正态分布。为了较容易地观察过拟合,我们考虑高维线性回归问题,如设维度 p = 200 p=200 p=200;同时,我们特意把训练数据集的样本数设低,如20。

%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2ln_train, n_test, num_inputs = 20, 100, 200
true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05features = torch.randn((n_train + n_test, num_inputs))
labels = torch.matmul(features, true_w) + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]

3.12.3 从零开始实现

下面先介绍从零开始实现权重衰减的方法。我们通过在目标函数后添加 L 2 L_2 L2范数惩罚项来实现权重衰减。

3.12.3.1 初始化模型参数

首先,定义随机初始化模型参数的函数。该函数为每个参数都附上梯度。

def init_params():w = torch.randn((num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]

3.12.3.2 定义 L 2 L_2 L2范数惩罚项

下面定义 L 2 L_2 L2范数惩罚项。这里只惩罚模型的权重参数。

def l2_penalty(w):return (w**2).sum() / 2

3.12.3.3 定义训练和测试

下面定义如何在训练数据集和测试数据集上分别训练和测试模型。与前面几节中不同的是,这里在计算最终的损失函数时添加了 L 2 L_2 L2范数惩罚项。

batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_lossdataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)def fit_and_plot(lambd):w, b = init_params()train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:# 添加了L2范数惩罚项l = loss(net(X, w, b), y) + lambd * l2_penalty(w)l = l.sum()if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()d2l.sgd([w, b], lr, batch_size)train_ls.append(loss(net(train_features, w, b), train_labels).mean().item())test_ls.append(loss(net(test_features, w, b), test_labels).mean().item())d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',range(1, num_epochs + 1), test_ls, ['train', 'test'])print('L2 norm of w:', w.norm().item())

3.12.3.4 观察过拟合

接下来,让我们训练并测试高维线性回归模型。当lambd设为0时,我们没有使用权重衰减。结果训练误差远小于测试集上的误差。这是典型的过拟合现象。

fit_and_plot(lambd=0)

输出:

L2 norm of w: 15.114808082580566

在这里插入图片描述

3.12.3.5 使用权重衰减

下面我们使用权重衰减。可以看出,训练误差虽然有所提高,但测试集上的误差有所下降。过拟合现象得到一定程度的缓解。另外,权重参数的 L 2 L_2 L2范数比不使用权重衰减时的更小,此时的权重参数更接近0。

fit_and_plot(lambd=3)

输出:

L2 norm of w: 0.035220853984355927

在这里插入图片描述

3.12.4 简洁实现

这里我们直接在构造优化器实例时通过weight_decay参数来指定权重衰减超参数。默认下,PyTorch会对权重和偏差同时衰减。我们可以分别对权重和偏差构造优化器实例,从而只对权重衰减。

def fit_and_plot_pytorch(wd):# 对权重参数衰减。权重名称一般是以weight结尾net = nn.Linear(num_inputs, 1)nn.init.normal_(net.weight, mean=0, std=1)nn.init.normal_(net.bias, mean=0, std=1)optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd) # 对权重参数衰减optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr)  # 不对偏差参数衰减train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:l = loss(net(X), y).mean()optimizer_w.zero_grad()optimizer_b.zero_grad()l.backward()# 对两个optimizer实例分别调用step函数,从而分别更新权重和偏差optimizer_w.step()optimizer_b.step()train_ls.append(loss(net(train_features), train_labels).mean().item())test_ls.append(loss(net(test_features), test_labels).mean().item())d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',range(1, num_epochs + 1), test_ls, ['train', 'test'])print('L2 norm of w:', net.weight.data.norm().item())

与从零开始实现权重衰减的实验现象类似,使用权重衰减可以在一定程度上缓解过拟合问题。

fit_and_plot_pytorch(0)

输出:

L2 norm of w: 12.86785888671875

在这里插入图片描述

fit_and_plot_pytorch(3)

输出:

L2 norm of w: 0.09631537646055222

在这里插入图片描述

小结

  • 正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。
  • 权重衰减等价于 L 2 L_2 L2范数正则化,通常会使学到的权重参数的元素较接近0。
  • 权重衰减可以通过优化器中的weight_decay超参数来指定。
  • 可以定义多个优化器实例对不同的模型参数使用不同的迭代方法。

注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2808850.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

445. 两数相加 II(Java)

目录 题目描述:输入:输出:代码实现: 题目描述: 给你两个 非空 链表来代表两个非负整数。数字最高位位于链表开始位置。它们的每个节点只存储一位数字。将这两数相加会返回一个新的链表。 你可以假设除了数字 0 之外&am…

SpringBoot Admin 详解

SpringBoot Admin 详解 一、Actuator 详解1.Actuator原生端点1.1 监控检查端点:health1.2 应用信息端点:info1.3 http调用记录端点:httptrace1.4 堆栈信息端点:heapdump1.5 线程信息端点:threaddump1.6 获取全量Bean的…

人力资源管理信息化系统如何支持企业开展管理诊断

人力资源顾问有限公司致力于帮助企业开展人力资源管理方面的各项提升改进工作,在长期的咨询工作中,最常听到企业提到的问题莫过于管理诊断方面的问题,事实上,很多企业在日常工作中,都意识到企业内部存在管理方面的问题…

String类常用方法(Java)

String类的常见方法(笔记) 1. charAt(int index) 返回此字符串中指定索引处的字符。 String str "hello"; char ch str.charAt(1); // 获取字符串中索引为1的字符,结果为 e2. compareTo(String anotherString)按字典顺序比较两个…

蓝桥杯算法赛 第 6 场 小白入门赛 解题报告 | 珂学家 | 简单场 + 元宵节日快乐

前言 整体评价 因为适逢元宵节,所以这场以娱乐为主。 A. 元宵节快乐 题型: 签到 节日快乐,出题人也说出来自己的心愿, 祝大家AK快乐! import java.util.Scanner;public class Main {public static void main(String[] args) {System.out.println(&qu…

智能枪弹柜管理系统-智能枪弹管理系统DW-S306

随着社会的发展和治安形势的日益严峻,对于枪弹的管理变得尤为重要。传统的手工记录和存放方式已经无法满足现代化、高效化、安全化的需求。因此,智能枪弹柜管理系统应运而 生。 在建设万兆主干、千兆终端的监控专网的基础上,弹药库安全技术…

python实现线下缓存最优算法

对于现代计算机为了加快数据存储速度,一般会采用多级缓存的方法,以最简单的二级缓存来说,数据会存放在两个地方,一个地方就是存在内存当中,另一个存放的地方就是存放在硬盘当中,但是这两个地方数据读取的速度是完全不同的。 而CPU从内存中读取数据的速度是要远远快与从硬…

Python Web开发记录 Day3:BootStrap

名人说:莫道桑榆晚,为霞尚满天。——刘禹锡(刘梦得,诗豪) 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 目录 三、BootStrap1、BootStrap-初体验2、BootStrap…

【论文精读】ConvNeXt

摘要 Vision Transformer是当前最先进的图像分类模型,但普通ViT在应用于一般计算机视觉任务(如目标检测和语义分割)时面临困难。故后来的分层Vision Transformer(如Swin Transformer)通过引入了几种卷积网络先验&#…

python程序设计基础:异常处理结构与程序调试、测试

第八章:异常处理结构与程序调试、测试 简单地说,异常是指程序运行时引发的错误,引发错误的原因有很多例如除零、下标越界、文件不存在、网络异常、类型错误、名字错误、字典键错误、磁盘空间不足,等等。 如果这些错误得不到正确的处理将会导致程序终止运行,而合理…

抖音视频下载工具|视频内容提取软件

引言部分: 针对抖音视频下载需求,我们团队自豪推出一款功能强大的工具,旨在解决用户获取抖音视频繁琐问题的困扰。我们通过基于C#开发的工具,让用户能够轻松通过关键词搜索实现自动批量抓取视频,并根据需求进行选择性批…

MAC地址学习和老化

MAC地址学习过程 一般情况下,MAC地址表是设备根据收到的数据帧里的源MAC地址自动学习而建立的。 图1 MAC地址学习示意图 如图1,HostA向SwitchA发送数据时,SwitchA从数据帧中解析出源MAC地址(即HostA的MAC地址)和VLAN…

PMP项目管理考试要注意些什么?

PMP考试和PMP备考过程中应该注意哪些问题? PMP备考完成后就要迎接实战考试了,考试前千万不要有多余的想法,顺其自然就行了,我想大家各种紧张、各种忧虑的原因大抵是因为考试成本考,担心考不过,其实只要你在…

excel标记文本中的关键词加红加粗

任务: 有这么一张表,关键词为 word,文本内容为 text,现在想把 text 中的 word 标红加粗,如果数据量少,文本段手动标还可以,多起来就不太方便了 代码: import pandas as pd import x…

ChatGPT助您提升求职技能

目录 ChatGPT可以作为求职技能的学习和提升平台 ChatGPT可以帮助求职者提升沟通和表达能力 ChatGPT还可以帮助求职者提升问题解决能力和创新能力 ChatGPT还可以帮助求职者建立自信心和自我推销能力 随着科技的迅速发展,人们的生活方式和工作方式也在不断地变革。…

nginx高级配置详解

目录 一、网页的状态页 1、状态页的基本配置 2、搭配验证模块使用 3、结合白名单使用 二、nginx 第三方模块 1、echo模块 1.1 编译安装echo模块 1.2 配置echo模块 三、nginx变量 1、内置变量 2、自定义变量 四、自定义图标 五、自定义访问日志 1、自定义日志格式…

安全测试:史上最全的攻防渗透信息收集方法、工具!

信息收集的意义 信息收集对于渗透测试前期来说是非常重要的。正所谓,知己知彼百战不殆,信息收集是渗透测试成功的保障,只有我们掌握了目标网站或目标主机足够多的信息之后,才能更好地进行渗透测试。 信息收集的方式可以分为两种…

虚 拟 化原理

1 概念: ①通俗理解: 虚拟化是在硬件和操作系统之间的实践 ②通过对计算机的服务层级的理解,理解虚拟化概念 抽离层级之间的依赖关系(服务器虚拟化) 2 虚拟化分类 ①按架构分类 ◆寄居架构:装在操作系统上…

JavaScript流程控制

文章目录 1. 顺序结构2. 分支结构2.1 if 语句2.2 if else 双分支语句2.3 if else if 多分支语句三元表达式 2.4 switch 语句switch 语句和 if else if语句区别 3. 循环结构3.1 for 循环断点调试 3.2 双重 for 循环3.3 while 循环3.4 do while 循环3.5 contiue break 关键字 4. …

蓝桥杯-乘积最大

原题链接:用户登录 题目描述 今年是国际数学联盟确定的“2000 --世界数学年”,又恰逢我国著名数学家华罗庚先生诞辰 90 周年。在华罗庚先生的家乡江苏金坛,组织了一场别开生面的数学智力竞赛的活动,你的一个好朋友 XZ 也有幸得以…