PyTorch 深度学习 || PyTorch 编程基础

PyTorch 编程基础

文章目录

  • PyTorch 编程基础
    • 1. backword 求梯度
    • 2. 常用的激活函数
      • 2.1 Sigmoid 函数
      • 2.2 ReLu 激活函数
      • 2.3 Leakly ReLu 激活函数
    • 2. 常用损失函数
      • 2.1 均方误差损失函数
      • 2.2 L1范数误差损失函数
      • 2.3 交叉熵损失函数
    • 3. 优化器

1. backword 求梯度

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(x, w) 
b = torch.add(w, 1)
y = torch.mul(a, b) # y=(x+w)(w+1)
y.backward() # 分别求出两个自变量的导数print(w.grad) # (w+1)+ (x+w) = x+2w+1 = 5
print(x.grad) # w+1 = 2

tensor([5.])

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
for i in range(3):a = torch.add(x, w)b = torch.add(w, 1)y = torch.mul(a, b) # y=(x+w)(w+1)y.backward() # (w+1)+(x+w) = x+2w+1 = 5print(w.grad) # 梯度在循环过程中进行了累加

tensor([5.])
tensor([10.])
tensor([15.])

2. 常用的激活函数

最常见的神经网络连接是全连接,这种连接是一种线性加权求和形式的线性运算。当没有激活函数的情况下,随着神经网络层数的增加,这种多层的线性运算和单层的线性运算在本质上没有任何差别。因此,激活函数无论从计算还是从生物解释上都非常重要。以下是几种常用的激活函数:

2.1 Sigmoid 函数

sigmoid是激活函数的一种,它会将样本值映射到0到1之间。sigmoid的公式如下:

y = 1 1 + e − input y=\frac{1}{1+e^{-\text{input}}} y=1+einput1

import numpy as np
import matplotlib.pyplot as pltdef sigmoid(x):return 1. / (1. + np.exp(-x))def plot_sigmoid():x = np.arange(-10, 10, 0.1)y = sigmoid(x)plt.plot(x, y)plt.show()if __name__ == '__main__':plot_sigmoid()

在这里插入图片描述

可以直接使用 PyTorch 自带的 Sigmoid 函数

import torch.nn as nn
import torch#取一组满足标准正态分布的随机数构成3*3的张量
t1 = torch.randn(3,3)
m = nn.Sigmoid()
t2 = m(t1)
print(t1)
print(t2)

tensor([[-1.1227, 0.8743, 0.7674],
[ 0.9877, 0.1209, 1.0413],
[ 0.2607, 0.6298, -0.1863]])
tensor([[0.2455, 0.7056, 0.6830],
[0.7286, 0.5302, 0.7391],
[0.5648, 0.6524, 0.4536]])

优点:

(1)便于求导的平滑函数

(2)能压缩数据,保证数据幅度不会有问题

(3)适合用于前向传播

缺点:

(1)容易出现梯度消失(gradient vanishing)的现象:当激活函数接近饱和区时,变化太缓慢,导数接近0,根据后向传递的数学依据是微积分求导的链式法则,当前导数需要之前各层导数的乘积,几个比较小的数相乘,导数结果很接近0,从而无法完成深层网络的训练。

(2)Sigmoid的输出不是0均值的:这会导致后层的神经元的输入是非0均值的信号,这会对梯度产生影响。以 f=sigmoid(wx+b)为例, 假设输入均为正数(或负数),那么对w的导数总是正数(或负数),这样在反向传播过程中要么都往正方向更新,要么都往负方向更新,导致有一种捆绑效果,使得收敛缓慢。

(3)幂运算相对耗时

2.2 ReLu 激活函数

卷积神经网络 CNN 中常用的激活函数是 ReLu,其数学表达式为:

ReLu ( x ) = max ⁡ { 0 , x } \text{ReLu}(x)=\max\{0,x\} ReLu(x)=max{0,x}

import numpy as np
import matplotlib.pyplot as pltdef relu(x):return np.maximum(0,x)def plot_relu():x=np.arange(-10,10,0.1)y=relu(x)plt.plot(x,y)plt.show()if __name__ == '__main__':plot_relu()  

在这里插入图片描述

调用 PyTorch 自带的函数

import torch.nn as nnm = nn.ReLU()
input = torch.randn(2)
output = m(input)
print(input)
print(output)

tensor([-0.8167, 1.2363])
tensor([0.0000, 1.2363])

优点:

(1)收敛速度比 sigmoid 和 tanh 快;(梯度不会饱和,解决了梯度消失问题)

(2)计算复杂度低,不需要进行指数运算

缺点:

(1)ReLu的输出不是zero-centered;

(2)Dead ReLU Problem(神经元坏死现象):某些神经元可能永远不会被激活,导致相应参数不会被更新(在负数部分,梯度为0)。产生这种现象的两个原因:参数初始化问题;learning rate太高导致在训练过程中参数更新太大。解决办法:采用Xavier初始化方法;以及避免将learning rate设置太大或使用adagrad等自动调节learning rate的算法。

(3)ReLu不会对数据做幅度压缩,所以数据的幅度会随着模型层数的增加不断扩张。

2.3 Leakly ReLu 激活函数

当 ReLu 输入值为负数的时候,输出值始终为 0 0 0,其一阶导数也始终为 0 0 0,这将会导致神经元不能更新参数,也就是神经元不学习了,这种现象叫“神经元坏死”,

为了解决 ReLu 函数的这个缺点,在 ReLu 函数的负半区间引入一个泄漏(leakly)值,称为 Leakly ReLu 函数,其数学表达式为:

ReLu ( x ) = max ⁡ { α x , x } \text{ReLu}(x)=\max\{\alpha x,x\} ReLu(x)=max{αx,x}

import numpy as np
import matplotlib.pyplot as pltdef leakly_relu(x):return np.array([i if i > 0 else 0.05*i for i in x ])def lea_relu_diff(x):return np.where(x > 0, 1, 0.01)x = np.arange(-10, 10, step=0.01)
y_sigma = leakly_relu(x)
y_sigma_diff = lea_relu_diff(x)
axes = plt.subplot(111)
axes.plot(x, y_sigma, label='leakly_relu')
axes.legend()
plt.show()

在这里插入图片描述

import torch.nn as nn
import torchLeakyReLU = nn.LeakyReLU(negative_slope=5e-2)
x = torch.randn(10)
value = LeakyReLU(x)
print(x)
print(value)

tensor([ 1.3149, 0.0643, 0.5569, -0.4575, 1.6295, -0.2836, -0.8015, 1.0364,
0.3108, 0.8266])
tensor([ 1.3149, 0.0643, 0.5569, -0.0229, 1.6295, -0.0142, -0.0401, 1.0364,
0.3108, 0.8266])

Leaky ReLU函数的特点:

  • Leaky ReLU函数通过把 x x x的非常小的线性分量给予负输入 0.01 x 0.01x 0.01x来调整负值的零梯度问题。
  • Leaky有助于扩大ReLU函数的范围,通常 α \alpha α的值为0.01左右。
  • Leaky ReLU的函数范围是负无穷到正无穷。

2. 常用损失函数

2.1 均方误差损失函数

loss ( x , y ) = 1 n ∥ x − y ∥ 2 2 = 1 n ∑ i = 1 n ( x i − y i ) 2 \text{loss}(\boldsymbol{x},\boldsymbol{y})=\frac{1}{n}\Vert\boldsymbol{x}-\boldsymbol{y}\Vert_2^2=\frac{1}{n}\sum_{i=1}^n(x_i-y_i)^2 loss(x,y)=n1xy22=n1i=1n(xiyi)2

import torchinput = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([4.0, 5.0, 6.0, 7.0])loss_fn = torch.nn.MSELoss(reduction='mean')
loss = loss_fn(input, target)
print(loss)

tensor(9.)

2.2 L1范数误差损失函数

loss ( x , y ) = 1 n ∥ x − y ∥ 1 = 1 n ∑ i = 1 n ∣ x i − y i ∣ \text{loss}(\boldsymbol{x},\boldsymbol{y})=\frac{1}{n}\Vert\boldsymbol{x}-\boldsymbol{y}\Vert_1=\frac{1}{n}\sum_{i=1}^n\vert x_i-y_i\vert loss(x,y)=n1xy1=n1i=1nxiyi

import torchloss = torch.nn.L1Loss(reduction='mean')
input = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([4.0, 5.0, 6.0, 7.0])
output = loss(input, target)
print(output)

tensor(3.)

2.3 交叉熵损失函数

h ( p , q ) = − ∑ x n p ( x ) ∗ log ⁡ q ( x ) h(p,q)=-\sum_{x}^np( x)*\log q(x) h(p,q)=xnp(x)logq(x)

import torchentroy = torch.nn.CrossEntropyLoss()
input = torch.Tensor([[-0.1181, -0.3682, -0.2209]])
target = torch.tensor([0])output = entroy(input, target)
print(output)

tensor(0.9862)

3. 优化器

import torch
import torch.nn
import torch.utils.data as Data
import matplotlib
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"matplotlib.rcParams['font.sans-serif'] = ['SimHei']#准备建模数据
x = torch.unsqueeze(torch.linspace(-1, 1, 500), dim=1)
y = x.pow(3)#设置超参数
LR = 0.01
batch_size = 15
epoches = 5
torch.manual_seed(10)#设置数据加载器
dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True,num_workers=2)#搭建神经网络
class Net(torch.nn.Module):def __init__(self, n_input, n_hidden, n_output):super(Net, self).__init__()self.hidden_layer = torch.nn.Linear(n_input, n_hidden)self.output_layer = torch.nn.Linear(n_hidden, n_output)def forward(self, input):x = torch.relu(self.hidden_layer(input))output = self.output_layer(x)return output#训练模型并输出折线图
def train():net_SGD = Net(1, 10, 1)net_Momentum = Net(1, 10, 1)net_AdaGrad = Net(1, 10, 1)net_RMSprop = Net(1, 10, 1)net_Adam = Net(1, 10, 1)nets = [net_SGD, net_Momentum, net_AdaGrad, net_RMSprop, net_Adam]#定义优化器optimizer_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)optimizer_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.6)optimizer_AdaGrad = torch.optim.Adagrad(net_AdaGrad.parameters(), lr=LR, lr_decay=0)optimizer_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)optimizer_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))optimizers = [optimizer_SGD, optimizer_Momentum, optimizer_AdaGrad, optimizer_RMSprop, optimizer_Adam]#定义损失函数loss_function = torch.nn.MSELoss()losses = [[], [], [], [], []]for epoch in range(epoches):for step, (batch_x, batch_y) in enumerate(loader):for net, optimizer, loss_list in zip(nets, optimizers, losses):pred_y = net(batch_x)loss = loss_function(pred_y, batch_y)optimizer.zero_grad()loss.backward()optimizer.step()loss_list.append(loss.data.numpy())plt.figure(figsize=(12,7))labels = ['SGD', 'Momentum', 'AdaGrad', 'RMSprop', 'Adam']for i, loss in enumerate(losses):plt.plot(loss, label=labels[i])plt.legend(loc='upper right',fontsize=15)plt.tick_params(labelsize=13)plt.xlabel('Train Step',size=15)plt.ylabel('Loss',size=15)plt.ylim((0, 0.3))plt.show()if __name__ == "__main__":train()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/353433.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

统计linux日志中请求被拒绝的ip

grep -oP dial tcp \K\S(?:[[:digit:]]) log.log.2023-06-03 | sort | uniq 结果:

NRF52832空中升级DFU

1.工具环境搭建 gcc-arm-none-eabi编译环境:GCC编译环境 Downloads | GNU Arm Embedded Toolchain Downloads – Arm Developer mingw 平台(win版的Linux命令行) Download MinGW - Minimalist GNU for Windows micro-ecc-master源码 GitHu…

hash传递攻击

简介 Pass the hash也就是Hash传递攻击,简称为PTH。模拟用户登录不需要用户明文密码只需要hash值就可以直接来登录目标系统。 利用前提条件是: 开启445端口开启ipc$共享 Metasploit pesexec模块 windows/smb/psexec 这里主要设置smbuser、smbPass …

月互联网十大热词出路 世博庙会、天上人间入选

世 博庙会、世 界杯家规、伪娘、京 十二条、菜奴、天 上人间、词媒体、零薪族、张悟本 世 博庙会 世博庙会是互动百科的智愿者针对上海世博会里游人熙熙攘攘排队、热热闹闹拍照、匆匆忙忙盖章的“走马观花”式的逛庙会游览方式的一种形象说法。世 界杯家规 世界杯家规&#xf…

2011最新整理分享平台代码参考

平台代码URL115收藏夹115http://fav.115.com139社区139http://www.139.com139邮箱139mailhttp://mail.10086.cn42区42quhttp://42qu.com5151http://www.51.com淘男网51taonanhttp://www.51taonan.com豆瓣9点9dianhttp://9.douban.com/就喜欢网9favhttp://www.9fav.comAsk.comas…

Web的基本漏洞--越权漏洞

目录 一、越权漏洞介绍 1.越权漏洞的原理 2.越权漏洞的分类 3.越权漏洞产生的原因 一、越权漏洞介绍 越权漏洞指的是应用在检查授权时存在纰漏,可以让攻击者获得低权限用户账户后,利用一些方式绕过权限检查,可以访问或者操作其他用户或者…

DevOps该怎么做?

年初在家待了一段时间看了两本书收获还是挺多的. 这些年一直忙于项目, 经历了软件项目的每个阶段, 多多少少知道每个阶段是个什么, 会做哪些事情浮于表面, 没有深入去思考每个阶段背后的理论基础, 最佳实践和落地工具. 某天leader说你书看完了, 只有笔记没有总结, 你就写个总结…

Java009——Java数据类型简单认识

围绕以下3点学习: 1、什么是Java数据类型? 2、Java数据类型的作用? 3、Java有哪些数据类型? 4、熟悉Java8大基本数据类型 一、什么是Java数据类型? 当我们写Java代码时,需要把数据保存在变量(…

win7系统卸载360管家之后无法上网怎么回事?

win7系统卸载360管家之后无法上网怎么回事?有用户电脑安装的360管家软件卸载了之后,发现网络连接也同时出现了错误。那么这个情况是什么原因呢?接下来一起来看看如何解决因为卸载360软件而导致的电脑无法上网问题的解决方法吧。 解决方法 1、…

360天擎无密码退出和卸载

无密码退出 删除配置(360Safe\EntClient\conf\EntBase.dat) **注:**删除时提示权限不够无法删除,使用360自带的文件粉碎机将文件粉碎即可无密码退出360天擎(重新打开程序时EntBase.dat会再次生成) 修改配置…

idesk卸载教程_【亲测可行】Autodesk 卸载工具,一键完全彻底卸载删除autodesk软件专门卸载工具...

autodesk卸载工具(AUTO Uninstaller)是专门为了针对autodesk类软件卸载不干净而导致autodesk安装失败问题进行研发的autodesk一键卸载工具。现在虽然360或一些卸载软件提供了强力卸载autodesk的工具,可以将autodesk注册表和一些autodesk目录的autodesk残留信息删除&…

关闭/卸载360画报/壁纸

前言 360画报和360壁纸指的是同一个组件,只是称呼不同,下文简称360画报,是360安全卫士和360安全浏览器自动安装的组件之一,是可以被关闭和卸载的。此组件会自动开启360屏保,用户不喜欢可以进行关闭或卸载。下文介绍关…

mysql删除工具_有没有mysql卸载工具

展开全部 没有工具,手工清理就636f707962616964757a686964616f31333365636633可以了。 1.打开命令行,运行命令 net start,查看一下mysql服务,如果开启就使用命令 net stop mysql 将其关闭。 注:如果提示无法关闭&#…

【亲测可行】Autodesk 卸载工具,一键完全彻底卸载删除autodesk软件专门卸载工具...

autodesk卸载工具(AUTO Uninstaller)是专门为了针对autodesk类软件卸载不干净而导致autodesk安装失败问题进行研发的autodesk一键卸载工具。现在虽然360或一些卸载软件提供了强力卸载autodesk的工具,可以将autodesk注册表和一些autodesk目录的autodesk残留信息删除&…

Autodesk 卸载工具

autodesk卸载工具(AUTO Uninstaller)是专门为了针对autodesk类软件卸载不干净而导致autodesk安装失败问题进行研发的autodesk一键卸载工具。现在虽然360或一些卸载软件提供了强力卸载autodesk的工具,可以将autodesk注册表和一些autodesk目录的autodesk残留信息删除&…

笔试强训9

作者:爱塔居 专栏:笔试强训 文章简介:简单记录学习的细碎~ day15 一. 单选 1.给出数据表 score(stu-id,name,math,english,Chinese), 下列语句正确的是( ) A Select sum(math),avg(chinese) from score B Select *,s…

解决360卸载之后遗留问题:windows defender无法开启

前几日,在对一台新电脑进行”净化工作“——卸载很多原装的垃圾软件,卸载了360之后发现windows defender无法打开,找到services.msc无法开启,启动按钮是灰色的,在查看了很多的教程之后,并确认windows defen…

一篇文章搞懂CMake(gcc、g++、cmake解释)

一篇文章搞懂CMake (gcc、g、cmake解释) 这里写目录标题 一篇文章搞懂CMake (gcc、g、cmake解释)gccgcmake1. CMake 流程如何使用cmake?简单的CMake.txt文件 参考 gcc gcc命令来自英文词组“GNU Compiler Collection”…

如何关闭计算机软件更新功能,如何关闭电脑自动更新功能

大家好,我是时间财富网智能客服时间君,上述问题将由我为大家进行解答。 以Windows 7电脑为例,关闭电脑自动更新功能的方法: 1、首先按下winr打开运行窗口。 2、接着输入services.msc并点击回车。 3、然后选择Windows update选项。…

关闭windows自动更新

1、win r ---- 输入services.msc进入服务 找到Windows Update 双击打开属性弹框 将启动类型改为禁用,并在恢复中将第一次失败改成无操作 2、win r ---- 输入gpedit.msc进入本地组策略编辑器 找到计算机配置 》管理模板 》Windows组件 》Windows更新,选…