【深度学习笔记】3_14 正向传播、反向传播和计算图

3.14 正向传播、反向传播和计算图

前面几节里我们使用了小批量随机梯度下降的优化算法来训练模型。在实现中,我们只提供了模型的正向传播(forward propagation)的计算,即对输入计算模型输出,然后通过autograd模块来调用系统自动生成的backward函数计算梯度。基于反向传播(back-propagation)算法的自动求梯度极大简化了深度学习模型训练算法的实现。本节我们将使用数学和计算图(computational graph)两个方式来描述正向传播和反向传播。具体来说,我们将以带 L 2 L_2 L2范数正则化的含单隐藏层的多层感知机为样例模型解释正向传播和反向传播。

3.14.1 正向传播

正向传播是指对神经网络沿着从输入层到输出层的顺序,依次计算并存储模型的中间变量(包括输出)。为简单起见,假设输入是一个特征为 x ∈ R d \boldsymbol{x} \in \mathbb{R}^d xRd的样本,且不考虑偏差项,那么中间变量

z = W ( 1 ) x , \boldsymbol{z} = \boldsymbol{W}^{(1)} \boldsymbol{x}, z=W(1)x,

其中 W ( 1 ) ∈ R h × d \boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)Rh×d是隐藏层的权重参数。把中间变量 z ∈ R h \boldsymbol{z} \in \mathbb{R}^h zRh输入按元素运算的激活函数 ϕ \phi ϕ后,将得到向量长度为 h h h的隐藏层变量

h = ϕ ( z ) . \boldsymbol{h} = \phi (\boldsymbol{z}). h=ϕ(z).

隐藏层变量 h \boldsymbol{h} h也是一个中间变量。假设输出层参数只有权重 W ( 2 ) ∈ R q × h \boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)Rq×h,可以得到向量长度为 q q q的输出层变量

o = W ( 2 ) h . \boldsymbol{o} = \boldsymbol{W}^{(2)} \boldsymbol{h}. o=W(2)h.

假设损失函数为 ℓ \ell ,且样本标签为 y y y,可以计算出单个数据样本的损失项

L = ℓ ( o , y ) . L = \ell(\boldsymbol{o}, y). L=(o,y).

根据 L 2 L_2 L2范数正则化的定义,给定超参数 λ \lambda λ,正则化项即

s = λ 2 ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) , s = \frac{\lambda}{2} \left(\|\boldsymbol{W}^{(1)}\|_F^2 + \|\boldsymbol{W}^{(2)}\|_F^2\right), s=2λ(W(1)F2+W(2)F2),

其中矩阵的Frobenius范数等价于将矩阵变平为向量后计算 L 2 L_2 L2范数。最终,模型在给定的数据样本上带正则化的损失为

J = L + s . J = L + s. J=L+s.

我们将 J J J称为有关给定数据样本的目标函数,并在以下的讨论中简称目标函数。

3.14.2 正向传播的计算图

我们通常绘制计算图来可视化运算符和变量在计算中的依赖关系。图3.6绘制了本节中样例模型正向传播的计算图,其中左下角是输入,右上角是输出。可以看到,图中箭头方向大多是向右和向上,其中方框代表变量,圆圈代表运算符,箭头表示从输入到输出之间的依赖关系。

在这里插入图片描述

图3.6 正向传播的计算图

3.14.3 反向传播

反向传播指的是计算神经网络参数梯度的方法。总的来说,反向传播依据微积分中的链式法则,沿着从输出层到输入层的顺序,依次计算并存储目标函数有关神经网络各层的中间变量以及参数的梯度。对输入或输出 X , Y , Z \mathsf{X}, \mathsf{Y}, \mathsf{Z} X,Y,Z为任意形状张量的函数 Y = f ( X ) \mathsf{Y}=f(\mathsf{X}) Y=f(X) Z = g ( Y ) \mathsf{Z}=g(\mathsf{Y}) Z=g(Y),通过链式法则,我们有

∂ Z ∂ X = prod ( ∂ Z ∂ Y , ∂ Y ∂ X ) , \frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right), XZ=prod(YZ,XY),

其中 prod \text{prod} prod运算符将根据两个输入的形状,在必要的操作(如转置和互换输入位置)后对两个输入做乘法。

回顾一下本节中样例模型,它的参数是 W ( 1 ) \boldsymbol{W}^{(1)} W(1) W ( 2 ) \boldsymbol{W}^{(2)} W(2),因此反向传播的目标是计算 ∂ J / ∂ W ( 1 ) \partial J/\partial \boldsymbol{W}^{(1)} J/W(1) ∂ J / ∂ W ( 2 ) \partial J/\partial \boldsymbol{W}^{(2)} J/W(2)。我们将应用链式法则依次计算各中间变量和参数的梯度,其计算次序与前向传播中相应中间变量的计算次序恰恰相反。首先,分别计算目标函数 J = L + s J=L+s J=L+s有关损失项 L L L和正则项 s s s的梯度

∂ J ∂ L = 1 , ∂ J ∂ s = 1. \frac{\partial J}{\partial L} = 1, \quad \frac{\partial J}{\partial s} = 1. LJ=1,sJ=1.

其次,依据链式法则计算目标函数有关输出层变量的梯度 ∂ J / ∂ o ∈ R q \partial J/\partial \boldsymbol{o} \in \mathbb{R}^q J/oRq

∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o . \frac{\partial J}{\partial \boldsymbol{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{o}}\right) = \frac{\partial L}{\partial \boldsymbol{o}}. oJ=prod(LJ,oL)=oL.

接下来,计算正则项有关两个参数的梯度:

∂ s ∂ W ( 1 ) = λ W ( 1 ) , ∂ s ∂ W ( 2 ) = λ W ( 2 ) . \frac{\partial s}{\partial \boldsymbol{W}^{(1)}} = \lambda \boldsymbol{W}^{(1)},\quad\frac{\partial s}{\partial \boldsymbol{W}^{(2)}} = \lambda \boldsymbol{W}^{(2)}. W(1)s=λW(1),W(2)s=λW(2).

现在,我们可以计算最靠近输出层的模型参数的梯度 ∂ J / ∂ W ( 2 ) ∈ R q × h \partial J/\partial \boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h} J/W(2)Rq×h。依据链式法则,得到

∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) . \frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)}. W(2)J=prod(oJ,W(2)o)+prod(sJ,W(2)s)=oJh+λW(2).

沿着输出层向隐藏层继续反向传播,隐藏层变量的梯度 ∂ J / ∂ h ∈ R h \partial J/\partial \boldsymbol{h} \in \mathbb{R}^h J/hRh可以这样计算:

∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o . \frac{\partial J}{\partial \boldsymbol{h}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{h}}\right) = {\boldsymbol{W}^{(2)}}^\top \frac{\partial J}{\partial \boldsymbol{o}}. hJ=prod(oJ,ho)=W(2)oJ.

由于激活函数 ϕ \phi ϕ是按元素运算的,中间变量 z \boldsymbol{z} z的梯度 ∂ J / ∂ z ∈ R h \partial J/\partial \boldsymbol{z} \in \mathbb{R}^h J/zRh的计算需要使用按元素乘法符 ⊙ \odot

∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) . \frac{\partial J}{\partial \boldsymbol{z}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{h}}, \frac{\partial \boldsymbol{h}}{\partial \boldsymbol{z}}\right) = \frac{\partial J}{\partial \boldsymbol{h}} \odot \phi'\left(\boldsymbol{z}\right). zJ=prod(hJ,zh)=hJϕ(z).

最终,我们可以得到最靠近输入层的模型参数的梯度 ∂ J / ∂ W ( 1 ) ∈ R h × d \partial J/\partial \boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d} J/W(1)Rh×d。依据链式法则,得到

∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) . \frac{\partial J}{\partial \boldsymbol{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{z}}, \frac{\partial \boldsymbol{z}}{\partial \boldsymbol{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(1)}}\right) = \frac{\partial J}{\partial \boldsymbol{z}} \boldsymbol{x}^\top + \lambda \boldsymbol{W}^{(1)}. W(1)J=prod(zJ,W(1)z)+prod(sJ,W(1)s)=zJx+λW(1).

3.14.4 训练深度学习模型

在训练深度学习模型时,正向传播和反向传播之间相互依赖。下面我们仍然以本节中的样例模型分别阐述它们之间的依赖关系。

一方面,正向传播的计算可能依赖于模型参数的当前值,而这些模型参数是在反向传播的梯度计算后通过优化算法迭代的。例如,计算正则化项 s = ( λ / 2 ) ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) s = (\lambda/2) \left(\|\boldsymbol{W}^{(1)}\|_F^2 + \|\boldsymbol{W}^{(2)}\|_F^2\right) s=(λ/2)(W(1)F2+W(2)F2)依赖模型参数 W ( 1 ) \boldsymbol{W}^{(1)} W(1) W ( 2 ) \boldsymbol{W}^{(2)} W(2)的当前值,而这些当前值是优化算法最近一次根据反向传播算出梯度后迭代得到的。

另一方面,反向传播的梯度计算可能依赖于各变量的当前值,而这些变量的当前值是通过正向传播计算得到的。举例来说,参数梯度 ∂ J / ∂ W ( 2 ) = ( ∂ J / ∂ o ) h ⊤ + λ W ( 2 ) \partial J/\partial \boldsymbol{W}^{(2)} = (\partial J / \partial \boldsymbol{o}) \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} J/W(2)=(J/o)h+λW(2)的计算需要依赖隐藏层变量的当前值 h \boldsymbol{h} h。这个当前值是通过从输入层到输出层的正向传播计算并存储得到的。

因此,在模型参数初始化完成后,我们交替地进行正向传播和反向传播,并根据反向传播计算的梯度迭代模型参数。既然我们在反向传播中使用了正向传播中计算得到的中间变量来避免重复计算,那么这个复用也导致正向传播结束后不能立即释放中间变量内存。这也是训练要比预测占用更多内存的一个重要原因。另外需要指出的是,这些中间变量的个数大体上与网络层数线性相关,每个变量的大小跟批量大小和输入个数也是线性相关的,它们是导致较深的神经网络使用较大批量训练时更容易超内存的主要原因。

小结

  • 正向传播沿着从输入层到输出层的顺序,依次计算并存储神经网络的中间变量。
  • 反向传播沿着从输出层到输入层的顺序,依次计算并存储神经网络中间变量和参数的梯度。
  • 在训练深度学习模型时,正向传播和反向传播相互依赖。

注:本节与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2808929.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

mysql的日志文件在哪?

阅读本文之前请参阅----MySQL 数据库安装教程详解(linux系统和windows系统) MySQL的日志文件通常包括错误日志、查询日志、慢查询日志和二进制日志等。这些日志文件的位置取决于MySQL的安装和配置。以下是一些常见的日志文件位置和如何找到它们&#xff…

电子签证小程序系统源码后台功能列表

基于ThinkPhp8.0uniapp 开发的电子签证小程序管理系统。能够真正帮助企业基于微信公众号H5、小程序、wap、pc、APP等,实现会员管理、数据分析,精准营销的电子商务管理系统。可满足企业新零售、批发、分销、预约、O2O、多店等各种业务需求,快速积累客户、…

从专业到大众:Sora如何颠覆传统视频制作模式

随着科技的飞速进步,人工智能(AI)技术正逐渐渗透到我们生活的方方面面。在视频制作领域,OpenAI推出的Sora模型为这一传统行业带来了前所未有的变革。Sora不仅改变了视频制作的技术门槛,更将视频制作从专业人士的手中解放出来,推向…

【线程池项目(三)】线程池CACHED模式的实现

在上一篇【线程池项目(二)】线程池FIXED模式的实现 中我们了解到到线程池fixed模式的大致实现原理,但对于一个比较完整的项目来说,我们还需要考虑到可能会发生的各种情况,比如用户提交的任务数可能在某一时刻急剧增加&…

贪心算法---前端问题

1、贪心算法—只关注于当前阶段的局部最优解,希望通过一系列的局部最优解来推出全局最优----但是有的时候每个阶段的局部最优之和并不是全局最优 例如假设你需要找给客户 n 元钱的零钱,而你手上只有若干种面额的硬币,如 1 元、5 元、10 元、50 元和 100…

【数据结构】排序(1)

目录 一、概念: 二、直接插入排序: 三、希尔排序: 四、直接选择排序: 五、堆排序: 六、冒泡排序: 一、概念: 排序的概念: 使一串记录,按照其中的某个或某些关键字…

Canvas实现打砖块

一.预览 二.代码 <!DOCTYPE html> <html lang"en"> <head><title>打砖块</title><style>#myCanvas {background: #eee; /* 设置画布的背景颜色为浅灰色 */display: block;margin: 0 auto; /* 使画布在页面中居中显示 */}</s…

高原制氧机的工作原理以及对高原地区生活质量的积极影响

在广袤的高原地区&#xff0c;空气稀薄&#xff0c;氧气含量相对较低&#xff0c;给当地居民和外来游客带来了不小的困扰。然而&#xff0c;随着科技的飞速进步&#xff0c;高原制氧机应运而生&#xff0c;成为改善高原生活质量的重要利器。恒业通将探讨高原制氧机的工作原理、…

【算法与数据结构】463、LeetCode岛屿的周长

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析&#xff1a;直接利用公式法&#xff0c;遇到一对相邻的陆地&#xff0c;总周长就减去2。那么周长公式为&#xff1…

微服务篇之任务调度

一、xxl-job的作用 1. 解决集群任务的重复执行问题。 2. cron表达式定义灵活。 3. 定时任务失败了&#xff0c;重试和统计。 4. 任务量大&#xff0c;分片执行。 二、xxl-job路由策略 1. FIRST&#xff08;第一个&#xff09;&#xff1a;固定选择第一个机器。 2. LAST&#x…

打包了一个QGIS3.34分享给大家

春节期间一时兴起编译打包了一个最新的QGIS版本QGIS3.34!秉承咱一贯理念&#xff0c;方便您使用也方便您不用&#xff01;该工具还是被打包为绿色版&#xff0c;即下即用&#xff0c;不用安装更无须卸载。微云的下载速度也比官方快很多&#xff0c;能大大节约您的时间提高您的工…

JavaAPI常用类02

目录 基本数据类型封装类 包装类常用属性方法 8中基本数据类型各自所对应的包装类 以下方法以java.lang.Integer为例 代码 运行 装箱和拆箱 装箱 何为装箱 代码 范围问题 代码 运行 拆箱 代码 String类 概述 代码 运行 创建形式 画图讲解 代码 运行 构造…

vscode使用restClient实现各种http请求

vscode使用restClient实现各种http请求 一&#xff0c;安装插件 首先&#xff0c;我们要在vscode的扩展中&#xff0c;搜索rest Client&#xff0c;然后安装它&#xff0c;这里我已经安装过了。 安装后&#xff0c;我们就可以使用rest client插件进行http各种操作了。 二&…

leetcode刷题-删除链表的倒数第N个节点(一次循环)

题目描述 解题思路 这几天玩的时间比较长&#xff0c;没有坚持更新。 解题思路很简单&#xff0c;也算是比较经典的问题。首先可以通道暴力解决&#xff0c;首先计算出来链表的长度&#xff0c;然后计算出来链表的长度&#xff0c;然后找到距离删除位置的前一个位置&#xff0…

131.乐理基础-快速识别音程(一)

上一个内容&#xff1a;130.乐理基础-倍增音程、倍减音程-CSDN博客 上一个内容里练习的答案&#xff1a; 开始不用数音数就可以辨别音程的方法&#xff0c;首先是不含升降号记号的两个音&#xff08;两个白键&#xff09;该怎样判断 方法的核心&#xff0c;就是音名中e-f和b-…

代码随想录算法训练营第28天 |第七章 回溯算法part04

学习目标&#xff1a; 93.复原IP地址 78.子集 90.子集II 学习内容&#xff1a; 93.复原IP地址 /class Solution { public:// string path;vector<string> result;bool isValid(const string& s, int start, int end) {if(start>end)return false;if(s[start]0&a…

SELF-RAG 论文详解

论文&#xff1a; https://arxiv.org/pdf/2310.11511.pdf code&#xff1a; https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_self_rag.ipynb?refblog.langchain.dev 相关工作 RAG 尽管 RAG 方法可以通过检索生成得到更为准确的结果&#xff0…

Codeforce Monsters Attack!(B题 前缀和)

题目描述&#xff1a; 思路&#xff1a; 本人第一次的想法是先杀血量低的第二次想法是先搞坐标近的第三次想法看到数据量这么大&#xff0c; 我先加个和看看貌似我先打谁都行&#xff0c;由此综合一下&#xff0c; 我们可以把每一个不同的坐标当作一轮从最小的坐标开始&#x…

老子云2024全站焕新,重塑3D轻量体验

3D模型当前应用广泛&#xff0c;正以惊人的速度实现数据增长&#xff0c;轻量化需求随之增多。老子云团队一直在探索如何借助自研轻量化技术的能力&#xff0c;打破用户模型处理思维惯性&#xff0c;构建更高效、实用、简单的体验范式&#xff0c;来帮助用户解决3D素材数据处理…

开源工具和框架

目录 开源工具和框架 一、 开源工具和框架 二、开源工具和框架在现代软件开发中的角色 1、基础设施建设&#xff1a; 2、开发效率提升&#xff1a; 3、代码质量保障&#xff1a; 4、技术创新&#xff1a; 三、广泛使用的开源项目分析 3.1、Linux 3.2、Git 3.3、Docke…