《动手学深度学习(PyTorch版)》笔记7.6

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.6 Residual Networks(ResNet)

随着我们设计越来越深的网络,深刻理解“新添加的层如何提升神经网络的性能”变得至关重要。

7.6.1 Function Class

首先,假设有一类特定的神经网络架构 F \mathcal{F} F,它包括学习速率和其他超参数设置。对于所有 f ∈ F f \in \mathcal{F} fF,存在一些参数集(例如权重和偏置),这些参数可以通过在合适的数据集上进行训练而获得。现在假设 f ∗ f^* f是我们真正想要找到的函数,如果是 f ∗ ∈ F f^* \in \mathcal{F} fF,那我们可以轻而易举的训练得到它,但通常我们不会那么幸运。我们将尝试找到一个函数 f F ∗ f^*_\mathcal{F} fF,这是我们在 F \mathcal{F} F中的最佳选择。例如,给定一个具有 X \mathbf{X} X特性和 y \mathbf{y} y标签的数据集,我们可以尝试通过解决以下优化问题来找到它:

f F ∗ : = a r g m i n f L ( X , y , f ) ,  f ∈ F . f^*_\mathcal{F} := \mathop{\mathrm{argmin}}_f L(\mathbf{X}, \mathbf{y}, f) \text{ , } f \in \mathcal{F}. fF:=argminfL(X,y,f) , fF.

为了得到更近似真正 f ∗ f^* f的函数,唯一合理的可能性是设计一个更强大的架构 F ′ \mathcal{F}' F。换句话说,我们预计 f F ′ ∗ f^*_{\mathcal{F}'} fF f F ∗ f^*_{\mathcal{F}} fF“更近似”。然而,如果 F ⊈ F ′ \mathcal{F} \not\subseteq \mathcal{F}' FF,则无法保证新的体系“更近似”。事实上, f F ′ ∗ f^*_{\mathcal{F}'} fF可能更糟:如下图所示,对于非嵌套函数(non-nested function)类,较复杂的函数类并不总是向“真”函数 f ∗ f^* f靠拢(复杂度由 F 1 \mathcal{F}_1 F1 F 6 \mathcal{F}_6 F6递增)。在下图的左边,虽然 F 3 \mathcal{F}_3 F3 F 1 \mathcal{F}_1 F1更接近 f ∗ f^* f,但 F 6 \mathcal{F}_6 F6却离的更远了。相反,对于下图右边的嵌套函数(nested function)类 F 1 ⊆ … ⊆ F 6 \mathcal{F}_1 \subseteq \ldots \subseteq \mathcal{F}_6 F1F6,我们可以避免上述问题。
在这里插入图片描述

因此,只有当较复杂的函数类包含较小的函数类时,我们才能确保提高它们的性能。对于深度神经网络,如果我们能将新添加的层训练成恒等映射(identity function) f ( x ) = x f(\mathbf{x}) = \mathbf{x} f(x)=x,新模型和原模型将同样有效。同时,由于新模型可能得出更优的解来拟合训练数据集,因此添加层似乎更容易降低训练误差。针对这一问题,何恺明等人提出了残差网络(ResNet)。其核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。于是,残差块(residual blocks)便诞生了,这个设计对如何建立深层神经网络产生了深远的影响。

7.6.2 Residual Blocks

在这里插入图片描述

如上图所示,假设我们的原始输入为 x x x,而希望学出的理想映射为 f ( x ) f(\mathbf{x}) f(x)。上图左边是一个正常块,虚线框中的部分需要直接拟合出该映射 f ( x ) f(\mathbf{x}) f(x),而右边是ResNet的基础架构–残差块(residual block),虚线框中的部分则需要拟合出残差映射 f ( x ) − x f(\mathbf{x}) - \mathbf{x} f(x)x。残差映射在现实中往往更容易优化。以恒等映射作为理想映射 f ( x ) f(\mathbf{x}) f(x),只需将上图右边虚线框内上方的加权运算(如仿射)的权重和偏置参数设成0,那么 f ( x ) f(\mathbf{x}) f(x)即为恒等映射。实际上,当理想映射 f ( x ) f(\mathbf{x}) f(x)极接近于恒等映射时,残差映射也易于捕捉恒等映射的细微波动。在残差块中,输入可通过跨层数据线路更快地向前传播,且可以避免某些梯度消失或梯度爆炸的问题。

在这里插入图片描述

ResNet沿用了VGG完整的 3 × 3 3\times 3 3×3卷积层设计。残差块里首先有2个有相同输出通道数的 3 × 3 3\times 3 3×3卷积层,每个卷积层后接一个批量规范化层和ReLU激活函数,然后我们通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。这样的设计要求2个卷积层的输出与输入形状一样,从而使它们可以相加。如果想改变通道数,就需要引入一个额外的 1 × 1 1\times 1 1×1卷积层来将输入变换成需要的形状后再做相加运算。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import matplotlib.pyplot as pltclass Residual(nn.Module):  #@savedef __init__(self, input_channels,num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)

如下图所示,此代码生成两种类型的网络:当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出;当use_1x1conv=True时,使用 1 × 1 1 \times 1 1×1卷积调整通道和分辨率。

在这里插入图片描述

blk = Residual(3,3)#输入和输出形状一致
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
print(Y.shape)blk = Residual(3,6, use_1x1conv=True, strides=2)#增加输出通道数的同时,减半输出的高和宽
print(blk(X).shape)#定义ResNet的模块
#b2-b5各有4个卷积层(不包括恒等映射的1x1卷积层),加上第一个7x7卷积层和最后一个全连接层,共有18层,因此这种模型通常被称为ResNet-18
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))def resnet_block(input_channels, num_channels, num_residuals,first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkb2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2774273.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

CC工具箱使用指南:【获取字段的所有唯一值】

一、简介 这个工具的目的是获取选定要素图层的字段的所有唯一值。 一般就是用于查看,比如说看一下规划用地有多少种地类,都是哪些地类。 二、工具参数介绍 点击【信息获取】组里的【获取字段的所有唯一值】工具: 即可打开下面的工具框界面…

Codeforces Round 923 (Div. 3)E. Klever Permutation 找规律,有共同区间

Problem - E - Codeforces 目录 Source of idea: 思路: 代码: 另一个up的找规律的解法: Source of idea: Codeforces Round 923(A-F题解) - 哔哩哔哩 (bilibili.com) 思路: 上面up分析的很好。两个相邻区间也就端点不一样&…

干货总结!Dockerfile编写优秀实践

Dockerfile 优秀实践 1. 善用.dockerignore文件 Docker 是CS架构,这就意味着 Client 和 Server 可以在不同的主机上。在构建镜像的时候,Client 会把所有需要的文件打包发送给 Server,这些发送的文件叫做 build context 默认情况下&#xf…

深度学习的新进展:解析技术演进与应用前景

深度学习的新进展:解析技术演进与应用前景 深度学习,作为人工智能领域的一颗璀璨明珠,一直以来都在不断刷新我们对技术和未来的认知。随着时间的推移,深度学习不断迎来新的进展,这不仅推动了技术的演进,也…

百面嵌入式专栏(面试题)C语言面试题22道

沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们将介绍C语言相关面试题 。 宏定义是在编译的哪个阶段被处理的?答案:宏定义是在编译预处理阶段被处理的。 解读:编译预处理:头文件包含、宏替换、条件编译、去除注释、添加行号。 写一个“标准”宏MIN,这个…

FPGA高端项目:解码索尼IMX327 MIPI相机转USB3.0 UVC 输出,提供FPGA开发板+2套工程源码+技术支持

目录 1、前言免责声明 2、相关方案推荐我这里已有的 MIPI 编解码方案 3、本 MIPI CSI-RX IP 介绍4、个人 FPGA高端图像处理开发板简介5、详细设计方案设计原理框图IMX327 及其配置MIPI CSI RX图像 ISP 处理图像缓存UVC 时序USB3.0输出架构FPGA逻辑设计工程源码架构SDK软件工程源…

2023年ABC123公众号年刊下载(PDF电子书)

Part1 前言 大家好,我是ABC_123。2023年公众号正式更名为"希潭实验室"。除了分享日常红队攻防、渗透测试技术文章之外,重点加强了APT案例分析方面的内容。公众号关注度得到进一步提升,关注人数已达到3万5千人。原计划在2023年编写…

统一身份认证系统架构设计与实践总结

随着互联网的快速发展和应用的普及,人们在各个网站和应用上需要不同的账号和密码进行身份认证。为了解决这个问题,统一身份认证系统应运而生。本文将总结统一身份认证系统的架构设计与实践经验,帮助读者了解如何设计和实现一个高效、安全的统…

2024幻兽帕鲁服务器多少钱一套?

2024年幻兽帕鲁服务器价格表更新,阿里云、腾讯云和华为云Palworld服务器报价大全,4核16G幻兽帕鲁专用服务器阿里云26元、腾讯云32元、华为云26元,阿腾云atengyun.com分享幻兽帕鲁服务器优惠价格表,多配置报价: 幻兽帕鲁…

第三百一十三回

文章目录 1. 概念介绍2. 实现方法2.1 obscureText属性2.2 decoration属性 3. 示例代码4. 内容总结 我们在上一章回中介绍了"如何实现倒计时功能"相关的内容,本章回中将介绍如何实现密码输入框.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍…

计划任务功能优化,应用商店上架软件超过100款,1Panel开源面板v1.9.6发布

2024年2月7日,现代化、开源的Linux服务器运维管理面板1Panel正式发布v1.9.6版本。 在v1.9.5和v1.9.6这两个小版本中,1Panel针对计划任务等功能进行了多项优化和Bug修复。此外,1Panel应用商店新增了3款应用,上架精选软件应用超过1…

算法随想录第五十二天打卡|300.最长递增子序列 , 674. 最长连续递增序列 , 718. 最长重复子数组

300.最长递增子序列 今天开始正式子序列系列,本题是比较简单的,感受感受一下子序列题目的思路。 视频讲解:动态规划之子序列问题,元素不连续!| LeetCode:300.最长递增子序列_哔哩哔哩_bilibili 代码随想录…

【python】绘制春节烟花

一、Pygame库春节烟花示例 下面是一个使用Pygame实现的简单春节烟花效果的示例代码。请注意,运行下面的代码之前,请确保计算机上已经安装了Pygame库。 import pygame import random import math from pygame.locals import *# 初始化pygame pygame.ini…

基于麻雀优化算法优化XGBoost参数的优化控制策略

目录 一、背景 二、算法流程图 三、附录 一、背景 为提高极端梯度提升(Extreme Gradient Boosting, XGBoost)集成算法在时间预测、信贷风险预测、工件参数预测、故障诊断预测等方面中的准确性,研究者提出了一种改进的麻雀算法(…

【我与Java的成长记】之String类详解

系列文章目录 能看懂文字就能明白系列 C语言笔记传送门 Java笔记传送门 🌟 个人主页:古德猫宁- 🌈 信念如阳光,照亮前行的每一步 文章目录 系列文章目录🌈 *信念如阳光,照亮前行的每一步* 前言一、字符串构…

深入浅出:Golang的Crypto/SHA256库实战指南

深入浅出:Golang的Crypto/SHA256库实战指南 介绍crypto/sha256库概览主要功能应用场景库结构和接口实例 基础使用教程字符串哈希化文件哈希化处理大型数据 进阶使用方法增量哈希计算使用Salt增强安全性多线程哈希计算 实际案例分析案例一:安全用户认证系…

【芯片设计- RTL 数字逻辑设计入门 13 -- generate_for 和 for】

文章目录 generate_forverilog codetestbench code仿真波形 for 循环verilog code仿真波形错误小结 generate_for 在某个module中包含了很多相似的连续赋值语句,请使用generata…for语句编写代码,替代该语句,要求不能改变原module的功能。 …

假设检验的过程

假设检验的核心思想是小概率事件在一次实验中不可能发生,假设检验就是利用小概率事件的发生进行反正。学习假设检验,有几个概念不能跳过,原假设、p值 1.原假设 假设检验的基本过程如下: 1)做出一个假设H0&#xff0c…

IEC 104电力规约详细解读(三) - 遥信

1.功能简述 遥信,、即状态量,是为了将断路器、隔离开关、中央信号等位置信号上送到监控后台的信息。遥信信息包括:反应电网运行拓扑方式的位置信息。如断路器状态、隔离开关状态;反应一次二次设备工作状况的运行信息,如…

豪掷770亿!华为员工集体“分红大狂欢”:至少14万人受益

豪掷770亿!华为员工集体“分红大狂欢”:至少14万人受益 近日,华为宣布了其2023年度分红计划,总金额高达770.85亿元,预计至少将惠及14万员工。这一消息引发了广泛关注和热议,成为业界的一大亮点。作为中国领…