为什么要梯度累积

文章目录

    • 梯度累积
      • 什么是梯度累积
      • 如何理解理解梯度累积
        • 梯度累积的工作原理
      • 梯度累积的数学原理
        • 梯度累积过程
        • 如何实现梯度累积
      • 梯度累积的可视化

梯度累积

什么是梯度累积

随着深度学习模型变得越来越复杂,模型的训练通常需要更多的计算资源,特别是在训练期间需要更多的内存。在训练深度学习模型时,在硬件资源有限的情况下,很难使用大批量数据进行有效学习。大批量数据通常可以带来更好的梯度估计,但同时也需要大量的内存。

梯度累积是一种巧妙的技术,它允许在不增加内存需求的情况下,有效地使用更大的批量数据来训练深度学习模型。

如何理解理解梯度累积

梯度累积本质上涉及将大批量划分为较小的子批量,并在这些子批量上累积计算出的梯度。这一过程模拟了使用较大批量训练的情况。

梯度累积的工作原理

以下是梯度累积过程的逐步分解:

  1. 分而治之:将你的硬件无法处理的大批量划分为更小的、可管理的子批量。
  2. 累积梯度:不是在处理每个子批量后更新模型参数,而是在几个子批量上累积梯度。
  3. 参数更新:在处理了预定义数量的子批量后,使用累积的梯度来更新模型参数。

这种方法使得模型能够利用大批量的稳定性和收敛性,而不必提高内存成本。

梯度累积的数学原理

在这里插入图片描述

梯度累积过程

在深度学习模型中,一个完整的前向和反向传播过程如下:

  • 前向传播:数据通过神经网络,层层处理后得到预测结果。

  • 损失计算:使用损失函数计算预测结果与实际值之间的差异。以平方误差损失函数为例:

    L ( θ ) = 1 2 ( h ( x k ) − y k ) 2 L(\theta) = \frac{1}{2} (h(x_k) - y_k)^2 L(θ)=21(h(xk)yk)2

    这里 L ( θ ) L(\theta) L(θ) 表示损失函数, θ \theta θ 代表模型参数, h ( x k ) h(x_k) h(xk) 是对输入 x k x_k xk 的预测输出, y k y_k yk 是对应的真实输出。

  • 反向传播:计算损失函数相对于模型参数的梯度(对上式求导):

    ∇ θ L ( θ ) = ( h ( x k ) − y k ) ⋅ ∇ θ h ( x k ) \nabla_\theta L(\theta) = (h(x_k) - y_k) \cdot \nabla_\theta h(x_k) θL(θ)=(h(xk)yk)θh(xk)

  • 梯度累积:在传统的训练过程中,每完成一个批次的数据处理后就会更新模型参数。而在梯度累积中,梯度不是立即用来更新参数,而是累加多个小批次的梯度:

    G = ∑ i = 1 n ∇ θ L i ( θ ) G = \sum_{i=1}^{n} \nabla_{\theta} L_i(\theta) G=i=1nθLi(θ)

    这里 G G G 是累积梯度, L i ( θ ) L_i(\theta) Li(θ) 是第 i i i 个batch的损失函数。

  • 参数更新:累积足够的梯度后,使用以下公式更新参数:

    θ = θ − η ⋅ G \theta = \theta - \eta \cdot G θ=θηG
    其中 l r lr lr 是学习率,用于控制更新的步长。

如何实现梯度累积

以下是在 PyTorch 中实现梯度累积的示例:

# 模型定义
model = ...
optimizer = ...# 累积步骤数
accumulation_steps = 4for epoch in range(num_epochs):optimizer.zero_grad()for i, (inputs, labels) in enumerate(dataloader):outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()# 只有在处理足够数量的子批量后才更新参数if (i + 1) % accumulation_steps == 0:optimizer.step()optimizer.zero_grad()# 如果批量大小不是累积步数的倍数,确保在每个epoch结束时更新if (i + 1) % accumulation_steps != 0:optimizer.step()optimizer.zero_grad()

这个例子中,accumulation_steps 定义了在参数更新前需要累积的batch数量。

梯度累积的可视化

为了更好地理解梯度累积的影响,可视化可以非常有帮助。以下是一个例子,说明如何在神经网络中可视化梯度流,以监控梯度是如何被累积和应用的:

import matplotlib.pyplot as plt# 绘制梯度流动的函数
def plot_grad_flow(named_parameters):ave_grads = []layers = []for n, p in named_parameters:if (p.requires_grad) and ("bias" not in n):layers.append(n)ave_grads.append(p.grad.abs().mean())plt.plot(ave_grads, alpha=0.3, color="b")plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k")plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")plt.xlim(xmin=0, xmax=len(ave_grads))plt.xlabel("层")plt.ylabel("平均梯度")plt.title("网络中的梯度流")plt.grid(True)plt.show()# 在训练过程中或训练后调用此函数以可视化梯度流
plot_grad_flow(model.named_parameters())

参考资料:

  1. Gradient Accumulation Algorithm

  2. Performing gradient accumulation with 🤗 Accelerate

  3. 梯度累加(Gradient Accumulation)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3016525.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

self-attention 的 CUDA 实现及优化 (上)

self-attention 的 CUDA 实现及优化 (上) 导 读 self-attention 是 Transformer 中最关键、最复杂的部分,也是 Transformer 优化的核心环节。理解 self-attention ,对于深入理解 Transformer 具有关键作用,本篇主要就围绕 self-attention 展…

java+jsp+Oracle+Tomcat 记账管理系统论文(完整版)

⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️⬇️ ➡️点击免费下载全套资料:源码、数据库、部署教程、论文、答辩ppt一条龙服务 ➡️有部署问题可私信联系 ⬆️⬆️⬆️​​​​​​​⬆️…

ThingsBoard版本控制配合Gitee实现版本控制

1、概述 2、架构 3、导出设置 4、仓库 5、同步策略 6、扩展 7、案例 7.1、首先需要在Giitee上创建对应同步到仓库地址 ​7.2、giit仓库只能在租户层面进行配置 7.3、 配置完成后:检查访问权限。显示已成功验证仓库访问!表示配置成功 7.4、添加设…

鸿蒙OpenHarmony南向:【Hi3516标准系统入门(命令行方式)】

Hi3516标准系统入门(命令行方式) 注意: 从3.2版本起,标准系统不再针对Hi3516DV300进行适配验证,建议您使用RK3568进行标准系统的设备开发。 如您仍然需要使用Hi3516DV300进行标准系统相关开发操作,则可能会…

【Linux】文件内容相关的命令,补充:管道符

1、查看文件内容 (1-1)查看文件内容:cat,tac,head,tail 查看文件内容cat 文件名查看文件内容并显示行号cat -n 文件名倒着查看文件内容(从最后一行开始)tac 文件名查看文件前10行…

力扣hot100:543. 二叉树的直径/108. 将有序数组转换为二叉搜索树

一、543. 二叉树的直径 LeetCode:543. 二叉树的直径 二叉树的直径 二叉树的 直径 是指树中任意两个节点之间最长路径的 长度 。 遇到二叉树的问题很容易去直接用求解的目标去定义递归函数。但是仔细考虑,返回树的直径并不能向上传播。因此我们可以拆…

SolidWorks进行热力学有限元分析二、模型装配

1.先打开软件,新建装配体 2.选中你要装配的零件,直接导入就行 3.鼠标点击左键直接先放进去 4.开始装配,点配合 5.选择你要接触的两个面,鼠标右键确定,然后把剩下的面对齐一下就行了 6.搞定

学习《现代密码学——基于安全多方计算协议的研究》 第一章

目录 前言 第1章 绪论 1.1 密码学的发展历史 1.2 现代密码学体制 1.3 现代密码学与安全多方计算 前言 近几年来,云计算、物联网、移动互联网等新概念、新技术被先后提出,促使信息技术飞速发展。同时,人类生活、沟通方式也随着新技术的普及…

泰克MDO3024示波器如何调整衰减倍数?

泰克MDO3024示波器是一款高性能的数字示波器,具备多种功能和调节选项,可以满足各种测试需求。其中一个重要的调节选项就是调整衰减倍数,通过调整衰减倍数,可以改变示波器的灵敏度和测量范围,帮助我们更好地观察和分析信…

奇诡 matlab 小 bug matlab git需要记录的改动太多

似乎是我有一次添加了太多的路径之后的事情。但是不敢说一定是这个导致的: 症状:只要对文本进行任何编辑操作,工作区就会出现"Processing … Cancel"的提示,如果不管的话这个提示不会消失,同时matlab变得越来…

【进程终止】退出信号 | 三种退出情况 | 如何进程终止returnexit_exit

目录 退出码 退出信号 进程终止情况3 如何进程终止 return退出 库函数exit 系统调用函数_exit ​exit和_exit的区别缓冲区 exit _exit 退出码 回顾上篇 代码跑完,结果正确(退出码为0)代码跑完,结果不正确(退…

springboot项目组合定时器schedule注解实现定时任务

springboot项目组合定时器schedule注解实现定时任务! 创建好springboot项目后,需要在启动类上增加注解开启定时器任务 下图所示: 增加这个注解,启动项目, package com.example.scheduledemo.util;import org.springf…

Linux进程通信-信号

信号概念 信号是 Linux 进程间通信的最古老的方式之一,是事件发生时对进程的通知机制,有时也称之为软件中断,它是在软件层次上对中断机制的一种模拟,是一种异步通信的方式。信号 可以导致一个正在运行的进程被另一个正在运行的异…

企业怎样进行IT外包以及IT外包服务内容

在数字化时代的浪潮中,企业逐渐认识到信息技术的关键作用,特别是制造业基地对于IT外包和运维服务的需求持续增长。然而,在诸多可供选择的IT外包和运维方案中,企业如何推动与IT外包公司的合作?本文将深入介绍IT外包方案…

nginx 启动,查看,停止

nginx 启动,查看,停止 启动 start nginx 查看是否启动成功 tasklist | findstr nginx 停止 nginx -s stop 测试配置文件的语法是否有误 nginx -t 重启nginx nginx-s reload

网络安全之动态路由OSPF基础

OSPF:开放式最短路径优先协议。 1、协议使用范围:IGP。 2、协议算法特点:链路状态型路由协议。 3、协议是否传递网络掩码:传递网络掩码(无类别的路由协议)。 4、协议封装:基于IP协议封装&am…

第六代移动通信介绍、无线网络类型、白皮书

关于6G 即第六代移动通信的介绍, 图解通信原理与案例分析-30:6G-天地互联、陆海空一体、全空间覆盖的超宽带移动通信系统_6g原理-CSDN博客文章浏览阅读1.7w次,点赞34次,收藏165次。6G 即第六代移动通信,6G 将在5G 的基…

《QT实用小工具·六十》Qt 多列时间轴控件

1、概述 源码放在文章末尾 Qt 多列时间轴控件。 可与多段字符串格式自由转换,也可手动添加列表项。 专门用来以时间轴作为事件线发展顺序的故事大纲。 特点 时间背包功能:记录所有物品或属性发生的变化,随时回溯 时间可输入任意内容&…

【区块链】智能合约简介

智能合约起源 智能合约这个术语至少可以追溯到1995年,是由多产的跨领域法律学者尼克萨博(NickSzabo)提出来的。他在发表在自己的网站的几篇文章中提到了智能合约的理念。他的定义如下:“一个智能合约是一套以数字形式定义的承诺&a…

【C++STL详解(八)】--------stack和queue的模拟实现

目录 前言 一、stack模拟实现 二、queue的模拟实现 前言 前面也介绍了stack和queue的常见接口,我们也知道stack和queue实际上是一种容器适配器,它们只不过是对底层容器的接口进行封装而已,所以模拟实现起来比较简单!一起来看看是…